File size: 11,054 Bytes
03562ce 9d3ebbc 03562ce a6be4c2 9d3ebbc 39157e5 03562ce 624a902 03562ce 624a902 03562ce 39157e5 03562ce 2336bf5 03562ce 9d3ebbc 03562ce 96f6622 03562ce cb5f0b7 03562ce 09fb7a6 03562ce 634cac7 2336bf5 634cac7 2336bf5 634cac7 03562ce 2336bf5 03562ce 2336bf5 03562ce 09fb7a6 03562ce 47d8577 5219c44 03562ce 47d8577 5219c44 03562ce |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 |
import torch.nn as nn
from transformers import BertPreTrainedModel, BertModel, AutoTokenizer
from huggingface_hub import hf_hub_download
import torch
from tqdm import tqdm
from .colbert_configuration import ColBERTConfig
from .tokenization_utils import QueryTokenizer, DocTokenizer
import os
class NullContextManager:
    """No-op context manager that hands back a fixed resource on entry.

    ``MixedPrecisionManager.context`` returns one of these in place of an
    autocast context when mixed precision is not activated.

    Args:
        dummy_resource: Object returned by ``__enter__`` (``None`` by default).
    """

    def __init__(self, dummy_resource=None):
        self.dummy_resource = dummy_resource

    def __enter__(self):
        resource = self.dummy_resource
        return resource

    def __exit__(self, *exc_info):
        # A falsy return value means exceptions raised in the ``with`` body propagate.
        return None
class MixedPrecisionManager:
    """Optional CUDA automatic-mixed-precision helper for a training step.

    When ``activated`` is truthy, ``context()`` yields ``torch.amp.autocast("cuda")``
    and gradients flow through a ``torch.amp.GradScaler("cuda")``. Otherwise
    ``context()`` is a no-op and ``backward``/``step`` act on the plain loss
    and optimizer.

    Args:
        activated: Whether to enable AMP. The ``scaler`` attribute is only
            created when this is truthy.
    """

    def __init__(self, activated):
        self.activated = activated
        if not self.activated:
            return
        self.scaler = torch.amp.GradScaler("cuda")

    def context(self):
        """Return the context manager to wrap the forward pass in."""
        if not self.activated:
            return NullContextManager()
        return torch.amp.autocast("cuda")

    def backward(self, loss):
        """Back-propagate ``loss``, scaling it first when AMP is active."""
        target = self.scaler.scale(loss) if self.activated else loss
        target.backward()

    def step(self, colbert, optimizer, scheduler=None):
        """Clip gradients to a total norm of 2.0, step the optimizer/scheduler, then zero grads.

        Args:
            colbert: Module whose ``parameters()`` are gradient-clipped.
            optimizer: Optimizer to step and then ``zero_grad()``.
            scheduler: Optional object whose ``step()`` is called after the optimizer step.
        """
        if not self.activated:
            torch.nn.utils.clip_grad_norm_(colbert.parameters(), 2.0)
            optimizer.step()
        else:
            # Gradients must be unscaled before clipping so the norm is measured in true units.
            self.scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(colbert.parameters(), 2.0, error_if_nonfinite=False)
            self.scaler.step(optimizer)
            self.scaler.update()

        if scheduler is not None:
            scheduler.step()

        optimizer.zero_grad()
class ConstBERT(BertPreTrainedModel):
    """
    Shallow wrapper around HuggingFace transformers. All new parameters should be defined at this level.
    This makes sure `{from,save}_pretrained` and `init_weights` are applied to new parameters correctly.

    Queries are encoded to one vector per token (padding zeroed out). Documents
    are encoded to a fixed number (32) of vectors each, via `doc_project`.
    """

    _keys_to_ignore_on_load_unexpected = [r"cls"]

    def __init__(self, config, colbert_config, verbose: int = 0):
        """Build the projection heads, tokenizers and the underlying BertModel.

        Args:
            config: HuggingFace BERT config (reads `hidden_size`).
            colbert_config: ColBERT config; this constructor reads `dim`,
                `doc_maxlen`, `query_maxlen`, `name_or_path`, `checkpoint`
                and `total_visible_gpus`.
            verbose: Forwarded to `QueryTokenizer`.
        """
        super().__init__(config)

        self.config = config
        self.colbert_config = colbert_config
        self.dim = colbert_config.dim

        # Per-token projection hidden_size -> dim.
        self.linear = nn.Linear(config.hidden_size, colbert_config.dim, bias=False)
        # Compresses the sequence axis (length doc_maxlen) to 32 document vectors; see `_doc`.
        self.doc_project = nn.Linear(colbert_config.doc_maxlen, 32, bias=False)
        # Kept so checkpoint weights load, but `_query` does not apply it.
        self.query_project = nn.Linear(colbert_config.query_maxlen, 64, bias=False)

        # Fetch tokenizer files from the Hugging Face Hub when they are not present
        # under name_or_path. The returned cache paths are not used here.
        for filename in ("tokenizer.json", "vocab.txt", "tokenizer_config.json", "special_tokens_map.json"):
            if not os.path.exists(os.path.join(colbert_config.name_or_path, filename)):
                hf_hub_download(repo_id=colbert_config.name_or_path, filename=filename)

        self.query_tokenizer = QueryTokenizer(colbert_config, verbose=verbose)
        self.doc_tokenizer = DocTokenizer(colbert_config)
        # NOTE(review): AMP is enabled unconditionally here, independent of `use_gpu`.
        self.amp_manager = MixedPrecisionManager(True)
        self.raw_tokenizer = AutoTokenizer.from_pretrained(colbert_config.checkpoint)
        self.pad_token = self.raw_tokenizer.pad_token_id
        self.use_gpu = colbert_config.total_visible_gpus > 0

        setattr(self, self.base_model_prefix, BertModel(config))
        self.init_weights()

    @property
    def LM(self):
        """The underlying language model stored under `base_model_prefix`."""
        return getattr(self, self.base_model_prefix)

    @classmethod
    def from_pretrained(cls, name_or_path, config=None, *args, **kwargs):
        """Load a ConstBERT checkpoint together with its ColBERT config.

        The ColBERT config is built from `name_or_path` and merged with the one
        stored in the checkpoint. Extra `*args`/`**kwargs` are accepted for
        signature compatibility but are NOT forwarded to the parent loader.
        """
        colbert_config = ColBERTConfig(name_or_path)
        colbert_config = ColBERTConfig.from_existing(ColBERTConfig.load_from_checkpoint(name_or_path), colbert_config)
        obj = super().from_pretrained(name_or_path, colbert_config=colbert_config, config=config)
        obj.base = name_or_path
        return obj

    @staticmethod
    def raw_tokenizer_from_pretrained(name_or_path):
        """Load an `AutoTokenizer` and tag it with the path it came from (`.base`)."""
        obj = AutoTokenizer.from_pretrained(name_or_path)
        obj.base = name_or_path
        return obj

    def _query(self, input_ids, attention_mask):
        """Encode query tokens: BERT -> `linear` -> zero padding rows -> L2-normalise over dim 2."""
        input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device)
        Q = self.bert(input_ids, attention_mask=attention_mask)[0]
        Q = self.linear(Q)

        # Zero out vectors at pad-token positions before normalising.
        mask = self.mask(input_ids, skiplist=[]).unsqueeze(2)
        Q = Q * mask

        return torch.nn.functional.normalize(Q, p=2, dim=2)

    def forward(self, input_ids, attention_mask):
        """
        Forward method for ONNX export and PyTorch compatibility.
        This will now call _doc to produce a fixed number of vectors.
        """
        return self._doc(input_ids, attention_mask)

    def _doc(self, input_ids, attention_mask, keep_dims=True):
        """Encode documents into 32 L2-normalised vectors each.

        `keep_dims` is validated for backward compatibility but otherwise
        ignored: a single tensor is always returned (no padding mask). The
        token axis must have length `colbert_config.doc_maxlen`, because that is
        `doc_project`'s input size. The output is cast to half precision when
        `use_gpu` is set.
        """
        assert keep_dims in [True, False, 'return_mask']

        input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device)
        D = self.bert(input_ids, attention_mask=attention_mask)[0]  # (batch, seq_len, hidden_size)
        D = self.linear(D)  # (batch, seq_len, dim)

        # nn.Linear acts on the last axis, so move seq_len there, compress it to
        # 32 slots with doc_project, then move it back.
        D = D.permute(0, 2, 1)  # (batch, dim, seq_len)
        D = self.doc_project(D)  # (batch, dim, 32)
        D = D.permute(0, 2, 1)  # (batch, 32, dim)

        D = torch.nn.functional.normalize(D, p=2, dim=2)
        if self.use_gpu:
            D = D.half()
        return D

    def mask(self, input_ids, skiplist):
        """Return a float mask that is 1.0 where `input_ids != pad_token` and 0.0 otherwise.

        `skiplist` is accepted but ignored (kept empty for ONNX export and inference).
        """
        return (input_ids != self.pad_token).float()

    def query(self, *args, to_cpu=False, **kw_args):
        """Run `_query` under `no_grad` and the AMP context; optionally move the result to CPU."""
        with torch.no_grad():
            with self.amp_manager.context():
                Q = self._query(*args, **kw_args)
                return Q.cpu() if to_cpu else Q

    def doc(self, *args, to_cpu=False, **kw_args):
        """Run `_doc` under `no_grad` and the AMP context; optionally move the result to CPU.

        If `_doc` returns a tuple, only its first element is moved to CPU.
        """
        with torch.no_grad():
            with self.amp_manager.context():
                D = self._doc(*args, **kw_args)
                if to_cpu:
                    return (D[0].cpu(), *D[1:]) if isinstance(D, tuple) else D.cpu()
                return D

    def encode_queries(self, queries, bsize=None, to_cpu=False, context=None, full_length_search=False):
        """Tokenize and encode one query string or a list of them.

        Args:
            queries: A string or a list of strings.
            bsize: If truthy, encode in batches of this size and concatenate them.
            to_cpu: Move the resulting embeddings to CPU (honoured on both paths).
            context, full_length_search: Forwarded to `QueryTokenizer.tensorize`.
        """
        if isinstance(queries, str):
            queries = [queries]

        if bsize:
            batches = self.query_tokenizer.tensorize(queries, context=context, bsize=bsize, full_length_search=full_length_search)
            batches = [self.query(input_ids, attention_mask, to_cpu=to_cpu) for input_ids, attention_mask in batches]
            return torch.cat(batches)

        input_ids, attention_mask = self.query_tokenizer.tensorize(queries, context=context, full_length_search=full_length_search)
        # Previously `to_cpu` was silently dropped on this non-batched path.
        return self.query(input_ids, attention_mask, to_cpu=to_cpu)

    def encode_documents(self, docs, bsize=None, keep_dims=True, to_cpu=False, showprogress=False, return_tokens=False):
        """Tokenize and encode one document string or a list of them.

        With `bsize`, results are reordered with the tokenizer's
        `reverse_indices` (presumably restoring input order — confirm against
        DocTokenizer). Output forms:

          * `keep_dims=True`: a (n_docs, n_vecs, dim) tensor.
          * `keep_dims='flatten'`: `(D, doclens)`, where D is a
            (n_docs * n_vecs, dim) CPU tensor and every entry of `doclens` is
            n_vecs (each document has the same fixed number of vectors).
          * `keep_dims=False`: a list of per-document (n_vecs, dim) tensors.

        With `return_tokens`, the reordered token-id rows are appended to the
        returned tuple. Without `bsize`, the result of `doc()` is returned
        directly.
        """
        if isinstance(docs, str):
            docs = [docs]

        assert keep_dims in [True, False, 'flatten']

        if bsize:
            text_batches, reverse_indices = self.doc_tokenizer.tensorize(docs, bsize=bsize)

            returned_text = []
            if return_tokens:
                returned_text = [text for batch in text_batches for text in batch[0]]
                returned_text = [returned_text[idx] for idx in reverse_indices.tolist()]
                returned_text = [returned_text]

            keep_dims_ = 'return_mask' if keep_dims == 'flatten' else keep_dims
            batches = [self.doc(input_ids, attention_mask, keep_dims=keep_dims_, to_cpu=to_cpu)
                       for input_ids, attention_mask in tqdm(text_batches, disable=not showprogress)]

            if keep_dims is True:
                D = _stack_3D_tensors(batches)
                return (D[reverse_indices], *returned_text)

            if keep_dims == 'flatten':
                # `_doc` returns a bare tensor per batch (no mask) with a fixed
                # number of vectors per document, so every vector is kept.
                D = torch.cat(batches)[reverse_indices]
                doclens = [D.size(1)] * D.size(0)
                D = D.reshape(-1, self.colbert_config.dim).cpu()
                return (D, doclens, *returned_text)

            assert keep_dims is False

            D = [d for batch in batches for d in batch]
            return ([D[idx] for idx in reverse_indices.tolist()], *returned_text)

        input_ids, attention_mask = self.doc_tokenizer.tensorize(docs)
        return self.doc(input_ids, attention_mask, keep_dims=keep_dims, to_cpu=to_cpu)
def _stack_3D_tensors(groups):
bsize = sum([x.size(0) for x in groups])
maxlen = max([x.size(1) for x in groups])
hdim = groups[0].size(2)
output = torch.zeros(bsize, maxlen, hdim, device=groups[0].device, dtype=groups[0].dtype)
offset = 0
for x in groups:
endpos = offset + x.size(0)
output[offset:endpos, :x.size(1)] = x
offset = endpos
return output |