|
r""" |
|
- basic bpe-tokenizer that doesn't uses byte pairing, insted uses set of initial unique characters |
|
to train the new vocab |
|
- set of initial characters = ["\n", "A", "C", "G", "T", " "] that can be present in a file or are |
|
needed for the tokenizer |
|
- save and load functions, saves two files, '.model' and 'vocab.json' and only '.model' file is loaded |
|
'vocab.json' is just for human interpretation |
|
""" |
|
|
|
from tqdm import tqdm |
|
import json |
|
import os |
|
# change the working directory to this file's directory so that relative model paths
# (eg. the prefix passed to save_model / load_model) resolve next to this script
current_dir = os.path.dirname(os.path.realpath(__file__))
os.chdir(current_dir)
|
|
|
class DNAtokenizer: |
|
def __init__(self): |
|
""" |
|
inital variables: |
|
- chars = set of unique characters that could be present in the file, that are needed |
|
- merges, vocab = empty dictonaries to store future merges and final vocab |
|
- vocab_size = initially it's equal to 6 or len(chars), updated later |
|
- str_to_idx, idx_to_str = functions enumerate chars to idx and idx to chars |
|
""" |
|
super().__init__() |
|
self.chars = ["\n", "A", "C", "G", "T", " "] |
|
self.vocab_size = len(self.chars) |
|
self.merges = {} |
|
self.vocab = {} |
|
self.string_to_index = {char: idx for idx, char in enumerate(self.chars)} |
|
self.index_to_string = {idx: char for idx, char in enumerate(self.chars)} |
|
|
|
def _encode(self, string): |
|
""" |
|
encoder: takes a string, returns a list of integers |
|
eg. AATGC --> ['2', '2', '5', '4', '3'] |
|
""" |
|
encoded = [self.string_to_index[char] for char in string] |
|
return encoded |
|
|
|
    def _decode(self, ids):
        """
        decoder: takes a list of integers, returns a string
        eg. [1, 1, 4, 3, 2] --> "AATGC"
        """
        decoded = ''.join([self.index_to_string[i] for i in ids])
        return decoded
|
|
|
def _get_stats(self, ids, counts=None): |
|
""" |
|
takes list of integers and returns dictionary of counts of pairs(consecutive ones) |
|
eg: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1} |
|
allows to update an existing dictionary of counts |
|
""" |
|
counts = {} if counts is None else counts |
|
for pair in zip(ids, ids[1:]): |
|
counts[pair] = counts.get(pair, 0) + 1 |
|
return counts |
|
|
|
def _merge(self, ids, pair, idx): |
|
""" |
|
in the list of integers, replaces all consecutive pair with the new integer token idx |
|
eg: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4] |
|
""" |
|
new_ids = [] |
|
i = 0 |
|
while i < len(ids): |
|
if i+1 < len(ids) and ids[i] == pair[0] and ids[i+1] == pair[1]: |
|
new_ids.append(idx) |
|
i += 2 |
|
else: |
|
new_ids.append(ids[i]) |
|
i += 1 |
|
return new_ids |
|
|
|
    def _build_vocab(self):
        """
        builds the base vocab mapping each initial character's index to the character itself,
        eg. {1: 'A', 2: 'C', 3: 'G', 4: 'T', ...};
        train() and load_model() then extend it with the merged tokens
        """
        return {i: ch for i, ch in enumerate(self.chars)}
|
|
|
def train(self, train_data, target_vocab): |
|
""" |
|
- takes in the data, encodes it using _encode() function, converts each unique char to index |
|
eg. AATGC --> ['2', '2', '5', '4', '3'] |
|
- performs iteration till n_merges i.e. target_vocab - self.vocab_size |
|
- each iteration, makes dictonary of 2 consecutive pairs and then merges the max occuring |
|
pair together |
|
- at the end uses merges to build final vocab |
|
|
|
Args: |
|
train_data (str): a big file containing lots of dna sequence |
|
target_vocab (integer): name tells you fucking idiot |
|
""" |
|
vocab = self._build_vocab() |
|
tokens = self._encode(train_data) |
|
ids = list(tokens) |
|
|
|
merges = {} |
|
        n_merges = target_vocab - self.vocab_size   # number of new tokens to learn
|
for i in tqdm(range(n_merges), desc='Training the tokenizer\t'): |
|
stats = self._get_stats(ids) |
|
pair = max(stats, key=stats.get) |
|
idx = self.vocab_size + i |
|
ids = self._merge(ids, pair, idx) |
|
merges[pair] = idx |
|
|
|
for (p0, p1), idx in merges.items(): |
|
vocab[idx] = vocab[p0] + vocab[p1] |
|
|
|
self.vocab = vocab |
|
self.merges = merges |
|
self.vocab_size = len(vocab) |
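        # hypothetical worked example (assuming train_data = "AATAAT" and target_vocab = 8):
        #   _encode("AATAAT")             -> [1, 1, 4, 1, 1, 4]
        #   merge 1: (1, 1) -> 6 ("AA")   -> [6, 4, 6, 4]
        #   merge 2: (6, 4) -> 7 ("AAT")  -> [7, 7]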
|
|
|
def continue_train(self, train_data, n_merges): |
|
""" |
|
- takes in the data, performs iteration till n_merges |
|
- continues from the last index of the loaded merges |
|
- each iteration, makes dictonary of 2 consecutive pairs and then merges the max occuring |
|
pair together (same as train()) |
|
- at the end uses merges to build final vocab |
|
|
|
Args: |
|
train_data (str): a big file containing lots of dna sequence |
|
n_merges (integer): no of merges |
|
|
|
** this function has some problems |
|
""" |
|
tokens = self._encode(train_data) |
|
ids = list(tokens) |
|
for i in tqdm(range(n_merges), desc='Training continue'): |
|
stats = self._get_stats(ids) |
|
pair = max(stats, key=stats.get) |
|
idx = self.vocab_size + i |
|
ids = self._merge(ids, pair, idx) |
|
self.merges[pair] = idx |
|
|
|
for (p0, p1), idx in self.merges.items(): |
|
self.vocab[idx] = self.vocab[p0] + self.vocab[p1] |
|
|
|
self.vocab_size = len(self.vocab) |
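        # hypothetical continuation example: after load_model('dna_tok.model') one could call
        #   tokenizer.continue_train(more_data, n_merges=50)
        # to learn 50 extra merges on top of the ones already loaded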
|
|
|
def encode(self, text): |
|
""" |
|
- takes in the input string, encodes it using initial vocab '_encode()' function |
|
- fetches merges from saved or loaded merges |
|
|
|
Args: |
|
train_data (str): string of dna sequence |
|
self.merges (dictonary): contains merges |
|
""" |
|
tokens = self._encode(text) |
|
ids = list(tokens) |
|
while len(ids) >= 2: |
|
stats = self._get_stats(ids) |
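            # pick the candidate pair that was learned earliest (lowest merge index);
            # pairs never seen during training map to inf, so if nothing matches we stop below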
|
pair = min(stats, key=lambda p: self.merges.get(p, float('inf'))) |
|
if pair not in self.merges: |
|
break |
|
|
|
idx = self.merges[pair] |
|
ids = self._merge(ids, pair, idx) |
|
return ids |
|
|
|
    def decode(self, de_text):
        """
        decoder: takes a list of token ids, looks each one up in the final vocab and joins the
        resulting strings back into a DNA sequence
        """
        tokens = [self.vocab[idx] for idx in de_text]
        text = ''.join(tokens)
        return text
|
|
|
def save_model(self, model_prefix): |
|
""" |
|
- basic save_model() funtion, saves two files, '.model' & 'vocab.json' |
|
- '.model' contians all the final merges, each on next line |
|
- 'vocab.json' contians the final vocab, for human interpretation |
|
|
|
Args: |
|
model_prefix (str): prefix along with the path |
|
self.merges (dict): contains final merges |
|
self.vocab (dict): contains final vocab |
|
""" |
|
model_file = model_prefix + '.model' |
|
|
|
with open(model_file, 'w', encoding='utf-8') as fwrite: |
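            # merge indices are not written out; load_model() re-derives them from the line order,
            # starting at the base vocab size (this relies on dict insertion order, Python 3.7+)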
|
for ids1, ids2 in self.merges: |
|
fwrite.write(f"{ids1} {ids2}\n") |
|
vocab_file = model_prefix + '_vocab.json' |
|
        with open(vocab_file, 'w') as f:
            json.dump(self.vocab, f, indent=2)   # indented for easier human inspection
|
print('model file saved successfully!') |
|
|
|
def load_model(self, model_path): |
|
""" |
|
- loads the '.model' file |
|
- re-writes the merges in the new merges dict |
|
- builds the vocab again for further use |
|
|
|
Args: |
|
model_path (str): path to the '.model' file |
|
""" |
|
assert model_path.endswith('.model') |
|
|
|
merges = {} |
|
idx = self.vocab_size |
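        # each line of the '.model' file is one merged pair; token indices are re-assigned in
        # file order, starting from self.vocab_size, so they match the order used at save time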
|
with open(model_path, 'r', encoding='utf-8') as fread: |
|
for line in fread: |
|
idx1, idx2 = map(int, line.split()) |
|
merges[(idx1, idx2)] = idx |
|
idx += 1 |
|
vocab = self._build_vocab() |
|
|
|
for (p0, p1), idx in merges.items(): |
|
vocab[idx] = vocab[p0] + vocab[p1] |
|
|
|
self.merges = merges |
|
self.vocab = vocab |
|
self.vocab_size = len(self.vocab) |
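

# minimal usage sketch (hypothetical file names and sizes, not part of the tokenizer itself):
# trains on a local DNA text file, saves the model, then reloads it and round-trips a sequence
if __name__ == '__main__':
    tokenizer = DNAtokenizer()

    # 'train.txt' is a hypothetical file containing only the characters \n, A, C, G, T and space
    with open('train.txt', 'r', encoding='utf-8') as f:
        data = f.read()

    tokenizer.train(data, target_vocab=256)    # learns 250 merges on top of the 6 base characters
    tokenizer.save_model('dna_tok')            # writes 'dna_tok.model' and 'dna_tok_vocab.json'

    # reload into a fresh tokenizer and round-trip a sequence
    loaded = DNAtokenizer()
    loaded.load_model('dna_tok.model')
    ids = loaded.encode('AATGC\n')
    assert loaded.decode(ids) == 'AATGC\n'
    print('round-trip ok, vocab size:', loaded.vocab_size)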