BPE tokenizer implementation
What is tokenization and why do we need it?
Tokenization is the process of encoding raw text into a shorter sequence of token IDs.
For example, the UTF-8 byte representation of the word "the" is [116, 104, 101], which requires 3 numbers to store. With tokenization we can merge these three bytes into a single number such as 256.
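You can verify the byte values above directly in Python; this is just a quick sanity check, not part of the tokenizer itself:

```python
# UTF-8 byte values of the word "the"
print(list("the".encode("utf-8")))  # [116, 104, 101]
```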
So tokenization is essentially a form of compression of the input text. Why compress it? Because a transformer's context length is limited, and we want to fit as much information as possible into that limited context window so the model can process more text at once. Tokenization therefore lets the model handle longer input text.
Byte-pair encoding (BPE) tokenization algorithm
The basic idea of the BPE algorithm is to identify the most frequently occurring adjacent pair of tokens and merge them by assigning a new token ID to that pair.
For example, suppose we have the training text `the cat in the hat`. In this text "th" appears twice, so in the first step the algorithm replaces it, turning the text into `<256>e cat in <256>e hat`. Each occurrence of the two-character pair is now a single number, so the text is shortened, i.e. compressed.
The vocabulary is
0: ...
...
256: "th"
We can repeat this process by merging `<256>e` into a new token `<257>`, because `<256>e` occurs twice in the text. The vocabulary is now
0: ...
...
256: "th"
257: "<256>e"
From the steps above we can see that the more we compress the text, the larger the vocabulary becomes, because every merge adds a new token to the vocabulary.
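To double-check the walkthrough above, we can count adjacent byte pairs in the example text. This is a quick sketch using `collections.Counter`; the tokenizer code below uses a plain dict instead:

```python
from collections import Counter

text = "the cat in the hat"
tokens = list(text.encode("utf-8"))            # raw UTF-8 bytes
pair_counts = Counter(zip(tokens, tokens[1:]))
# (116, 104) is the byte pair ('t', 'h'); it appears twice, as in the example
print(pair_counts[(116, 104)])                 # 2
```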
Steps
- Identify frequent pairs
- Replace and record
- Repeat until no gains
Code
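All snippets below assume that `tokens` holds the byte-level encoding of some training text. A minimal sketch of how to build it (the training text here is only a stand-in):

```python
# one initial token per UTF-8 byte of the training text
training_text = "the cat in the hat"  # stand-in; use your real training text
tokens = list(training_text.encode("utf-8"))
```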
1. Identify frequent pairs

```python
def get_stats(tokens):
    # count how often each adjacent pair of tokens occurs
    pair_count = {}
    for pair in zip(tokens, tokens[1:]):
        pair_count[pair] = pair_count.get(pair, 0) + 1
    return pair_count

stats = get_stats(tokens)
print(stats)
top_pair = max(stats, key=stats.get)
print("most frequent pair", top_pair, "occurrence count", stats[top_pair])
```
2. Replace and record (repeat until the desired vocabulary size is reached)
```python
def merge(ids, pair, idx):
    # replace every occurrence of `pair` in `ids` with the new token `idx`
    newids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids

# ---
vocab_size = 276  # the desired final vocabulary size
num_merges = vocab_size - 256
ids = list(tokens)  # copy so we don't destroy the original list

merges = {}  # (int, int) -> int
for i in range(num_merges):
    stats = get_stats(ids)
    pair = max(stats, key=stats.get)
    idx = 256 + i
    print(f"merging {pair} into a new token {idx}")
    ids = merge(ids, pair, idx)
    merges[pair] = idx
```
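After training, it is worth checking how much the merges actually compressed the training sequence. A small sketch (not part of the original listing above):

```python
# compare the raw byte sequence with the merged sequence
print("raw length:", len(tokens))
print("merged length:", len(ids))
print(f"compression ratio: {len(tokens) / len(ids):.2f}X")
```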
Encoding
Once we finish training our tokenizer, we can use it to tokenize any new input text.
```python
def encode(text):
    # convert the text to raw UTF-8 bytes, then apply the learned merges
    tokens = list(text.encode("utf-8"))
    while len(tokens) >= 2:
        stats = get_stats(tokens)
        # pick the pair whose merge was learned earliest (lowest merge index)
        pair = min(stats, key=lambda p: merges.get(p, float('inf')))
        if pair not in merges:
            break  # nothing left to merge
        idx = merges[pair]
        tokens = merge(tokens, pair, idx)
    return tokens

print(decode(encode("hello world")))  # decode is defined in the Decoding section below
```
Using the `min` function with a custom key selects the most suitable pair of tokens to merge based on defined criteria. Here's an example to illustrate its purpose: imagine the merge priorities of token pairs are stored in a dictionary `merges`, and you want to find the pair with the lowest priority value (i.e. the merge learned earliest) that exists in this dictionary.
Consider the following example:
```python
merges = {('th', 'e'): 1, ('a', 'n'): 2, ('i', 'n'): 3, ('s', 't'): 4}
stats = {('th', 'e'): 5, ('a', 'n'): 3, ('i', 'n'): 2, ('h', 'e'): 4}
```
In this case, the `merges` dictionary maps pairs to their priorities, and the `stats` dictionary maps pairs to their frequencies. You want to find the pair with the lowest priority that exists in the `merges` dictionary.
The code `pair = min(stats, key=lambda p: merges.get(p, float("inf")))` will evaluate as follows:
- For each pair in `stats`:
  - ('th', 'e'): `merges.get(('th', 'e'), float("inf"))` returns 1
  - ('a', 'n'): `merges.get(('a', 'n'), float("inf"))` returns 2
  - ('i', 'n'): `merges.get(('i', 'n'), float("inf"))` returns 3
  - ('h', 'e'): `merges.get(('h', 'e'), float("inf"))` returns `float("inf")` (not in `merges`)
- The `min` function selects the pair with the smallest value: ('th', 'e'), with a priority of 1
Therefore, `pair` will be ('th', 'e'), the pair with the lowest priority value, i.e. the merge that was learned earliest during training. This approach ensures that only pairs defined in the `merges` dictionary are considered, and that merges are applied in the same order in which they were learned, keeping encoding consistent with training.
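Putting the example into runnable form (the `merges` and `stats` values are the illustrative ones from this section; string pairs stand in for the integer pairs used by the real tokenizer):

```python
merges = {('th', 'e'): 1, ('a', 'n'): 2, ('i', 'n'): 3, ('s', 't'): 4}
stats = {('th', 'e'): 5, ('a', 'n'): 3, ('i', 'n'): 2, ('h', 'e'): 4}

# pairs not present in `merges` map to infinity and are never chosen
pair = min(stats, key=lambda p: merges.get(p, float("inf")))
print(pair)  # ('th', 'e')
```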
Decoding
Since each new token (id >= 256) is defined as the concatenation of two sub-tokens, we can expand every token back into its sub-tokens (and ultimately into raw bytes) to rebuild the original text.
```python
vocab = {idx: bytes([idx]) for idx in range(256)}
# merges is iterated in insertion order, so both sub-tokens are already in vocab
for (p0, p1), idx in merges.items():
    vocab[idx] = vocab[p0] + vocab[p1]

def decode(ids):
    # given ids (list of integers), return Python string
    tokens = b"".join(vocab[idx] for idx in ids)
    text = tokens.decode("utf-8", errors="replace")
    return text

print(decode([128]))
```
References
Tokenization notebook by Andrej Karpathy