model.py CHANGED
@@ -1,31 +1,7 @@
-###################################################################################################
-###################################################################################################
-###################################################################################################
-
-import collections
-import logging
-
-import json
-import math
-import os
-import re
-from collections import OrderedDict
-from functools import partial
-from typing import List, Optional, Tuple, Union
-
-import torch
-import torch.nn as nn
-
-
-
-########################################################
-########################################################
-########################################################
-########################################################
-
-
 from typing import Callable, Optional, Tuple
+
 import copy
+import json
 import math
 import multiprocessing
 import os
@@ -34,7 +10,6 @@ import torch
 import torch.nn as nn
 import transformers
 
-
 class ContextualModelConfig(transformers.configuration_utils.PretrainedConfig):
     """We create a dummy configuration class that will just set properties
     based on whatever kwargs we pass in.
@@ -54,14 +29,13 @@ class ContextualModelConfig(transformers.configuration_utils.PretrainedConfig):
                 continue
         super().__init__()
 
-
 def load_embedder_and_tokenizer(name: str) -> Tuple[
     transformers.PreTrainedModel,
     transformers.PreTrainedTokenizer
 ]:
-
+    assert name is not None, "name must be provided to load_embedder_and_tokenizer"
     if name.startswith("nomic") or (name == "bert-base-uncased"):
-        model =
+        model = transformers.AutoModelForMaskedLM.from_pretrained(name, trust_remote_code=True).bert
         tokenizer = transformers.AutoTokenizer.from_pretrained(name)
     elif name in ["gtr-base", "gtr_base"]:
         model = transformers.AutoModel.from_pretrained(
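For reference, the new loader branch returns the BERT encoder of a masked-LM checkpoint plus its tokenizer. A rough usage sketch, assuming Hub access and using bert-base-uncased (one of the names handled above; the nomic checkpoints additionally need trust_remote_code=True):

import torch
import transformers

# Rough usage sketch; "bert-base-uncased" needs no trust_remote_code.
name = "bert-base-uncased"
model = transformers.AutoModelForMaskedLM.from_pretrained(name).bert
tokenizer = transformers.AutoTokenizer.from_pretrained(name)

inputs = tokenizer("a short test document", return_tensors="pt")
with torch.no_grad():
    hidden_states = model(**inputs).last_hidden_state  # (1, seq_len, hidden_size)
print(hidden_states.shape)  # hidden size is 768 for BERT-base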
@@ -106,8 +80,6 @@ def load_embedder_and_tokenizer(name: str) -> Tuple[
     # from optimum.bettertransformer import BetterTransformer
     # model = BetterTransformer.transform(model)
     return model, tokenizer
-
-
 def get_world_size() -> int:
     try:
         return torch.distributed.get_world_size()
@@ -318,7 +290,7 @@ def maxsim(
         sub_x = slice_tensor_rows(X, start, end)
         if debug_mem_usage: print(f"[maxsim] step {i} cuda mem free/total = {torch.cuda.mem_get_info()}")
         if debug_mem_usage: print("[maxsim] sub_x.shape:", sub_x.shape, "//", "y.shape:", y.shape)
-        sub_sim = sub_x @ y # TODO –
+        sub_sim = sub_x @ y # TODO – Implement sparse max here to save mem!
         sub_sim = sub_sim
         if maximize:
             sub_max_sim_v, sub_max_sim_i = sub_sim.to_dense().max(dim=-1)
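The hunk above sits inside a blockwise max-similarity loop: each block sub_x is multiplied against y and reduced with a max over the last dimension. A minimal standalone sketch of that pattern, assuming dense tensors and an explicit transpose (the helper name and block size below are illustrative, not the repo's API):

import torch

def blockwise_maxsim(X: torch.Tensor, Y: torch.Tensor, block_size: int = 1024):
    """For each row of X, take the max dot product over rows of Y, processing X
    in row-blocks to bound peak memory (illustrative helper, not the repo's API)."""
    max_vals, max_idxs = [], []
    for start in range(0, X.shape[0], block_size):
        sub_x = X[start:start + block_size]   # (b, d) block of queries
        sub_sim = sub_x @ Y.T                 # (b, n) similarity block
        v, i = sub_sim.max(dim=-1)            # reduce over the corpus dimension
        max_vals.append(v)
        max_idxs.append(i)
    return torch.cat(max_vals), torch.cat(max_idxs)

values, indices = blockwise_maxsim(torch.randn(5000, 64), torch.randn(300, 64))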
@@ -471,7 +443,6 @@ def disable_causality(model: torch.nn.Module):
         f"Set is_causal=False in {disabled_modules} modules from model type {type(model)}"
     )
 
-
 class ContextualModelMixin(nn.Module):
     @property
     def num_corpus_tokens(self) -> int:
@@ -511,9 +482,6 @@ class ContextualModelMixin(nn.Module):
             # Auto-expand for a batch.
             dataset_embeddings = dataset_embeddings[None, :, :] # (b, d) -> (1, b, d)
         dataset_embeddings = dataset_embeddings.to(input_ids.device)
-
-        if len(dataset_embeddings.shape) < 3:
-            raise ValueError(f"dataset_embeddings must have at least 3 dimensions, got {dataset_embeddings.shape}")
 
         batch_size = input_ids.shape[0]
         if (self.transductive_tokens_per_document > 1):
@@ -532,11 +500,9 @@ class ContextualModelMixin(nn.Module):
                 dataset_embeddings = dataset_embeddings[R].reshape((batch_size, self.num_corpus_tokens, self.hidden_size))
             else:
                 dataset_embeddings = dataset_embeddings.reshape((1, self.num_corpus_tokens, self.hidden_size))
+            # print("reshaped to dataset_embeddings.shape =", dataset_embeddings.shape)
 
-
-        if dataset_embeddings.shape[1] < self.num_corpus_tokens:
-            raise ValueError(f"dataset_embeddings must have at least {self.num_corpus_tokens} tokens, got {dataset_embeddings.shape[1]}")
-        elif dataset_embeddings.shape[1] > self.num_corpus_tokens:
+        if dataset_embeddings.shape[1] > self.num_corpus_tokens:
             # If too many dataset embeddings are passed in, just take the first N until
             # we have the proper number.
             dataset_embeddings = dataset_embeddings[:, :self.num_corpus_tokens, :]
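The logic above expects dataset_embeddings to end up with shape (batch or 1, num_corpus_tokens, hidden_size), and the new branch simply truncates any extra rows. A small standalone sketch of that shape handling, with placeholder sizes:

import torch

# Placeholder values, not the model's real configuration.
num_corpus_tokens, hidden_size = 512, 768
dataset_embeddings = torch.randn(600, hidden_size)  # more rows than needed

if dataset_embeddings.ndim == 2:
    # Auto-expand (n, d) to a batch of one: (1, n, d).
    dataset_embeddings = dataset_embeddings[None, :, :]

if dataset_embeddings.shape[1] > num_corpus_tokens:
    # Too many embeddings passed in: keep only the first num_corpus_tokens.
    dataset_embeddings = dataset_embeddings[:, :num_corpus_tokens, :]

assert dataset_embeddings.shape == (1, num_corpus_tokens, hidden_size)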
@@ -558,6 +524,8 @@ class ContextualModelMixin(nn.Module):
             null_embeddings = self.sequence_dropout_null_embedding[None, None].expand(batch_size, corpus_size, -1)
             dataset_embeddings = null_embeddings
 
+        # print(f"[ContextualModelMixin] dataset_embeddings.shape = {dataset_embeddings.shape}")
+
         # backbone_max_seq_length = self.backbone.config.max_trained_positions
         # assert batch_size + (2 * self.n_soft_prompt + corpus_size) <= backbone_max_seq_length, "too many hard negatives for backbone model"
         soft_prompt = torch.ones((1, self.hidden_size), device=dataset_embeddings.device, dtype=dataset_embeddings.dtype)
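For context, the surrounding code swaps the per-example dataset embeddings for a learned null embedding when sequence dropout fires. A hedged sketch of that substitution, with made-up shapes and dropout rate:

import torch

# Names mirror the attributes used above; the values here are illustrative only.
batch_size, corpus_size, hidden_size = 2, 4, 8
sequence_dropout_prob = 0.5
sequence_dropout_null_embedding = torch.nn.Parameter(torch.randn(hidden_size) * 0.01)
dataset_embeddings = torch.randn(batch_size, corpus_size, hidden_size)

if torch.rand(()).item() < sequence_dropout_prob:
    # Broadcast the learned null vector to (batch, corpus, hidden), mirroring
    # sequence_dropout_null_embedding[None, None].expand(batch_size, corpus_size, -1).
    dataset_embeddings = sequence_dropout_null_embedding[None, None].expand(
        batch_size, corpus_size, -1
    )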
@@ -630,8 +598,15 @@ class BiEncoder(transformers.PreTrainedModel):
             [d1, d2, d3, hn1_1, hn1_2, hn2_1, hn2_2, hn3_1, hn3_2]
         for a corpus with three documents and two hard negatives per document
         """
+        # del dataset_input_ids
+        # del dataset_attention_mask
         del token_type_ids
 
+        # from cde.lib.dist import get_rank
+        # tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
+        # if get_rank() == 0:
+        #     breakpoint()
+        # torch.distributed.barrier()
         outputs = (
             self.embedder(
                 input_ids=input_ids,
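The docstring above describes the flattened batch layout: documents first, then hard negatives grouped per document. A tiny standalone sketch of that layout and how it could be regrouped; names and counts are illustrative:

# Sketch of the layout [d1, d2, d3, hn1_1, hn1_2, hn2_1, hn2_2, hn3_1, hn3_2].
docs = ["d1", "d2", "d3"]
hard_negatives = ["hn1_1", "hn1_2", "hn2_1", "hn2_2", "hn3_1", "hn3_2"]
flat_batch = docs + hard_negatives

n_docs, n_hn_per_doc = 3, 2
negatives_per_doc = [
    flat_batch[n_docs + i * n_hn_per_doc : n_docs + (i + 1) * n_hn_per_doc]
    for i in range(n_docs)
]
# negatives_per_doc == [["hn1_1", "hn1_2"], ["hn2_1", "hn2_2"], ["hn3_1", "hn3_2"]]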
@@ -801,7 +776,6 @@ class DatasetConditionedAutoregressive(transformers.PreTrainedModel, ContextualModelMixin):
         return output
 
 
-
 class DatasetConditionedBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
     def __init__(
         self,
@@ -812,14 +786,12 @@ class DatasetConditionedBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
         self.backbone = dataset_backbone
         self.hidden_size = self.backbone.config.hidden_size
         self.hidden_size = dataset_backbone.config.hidden_size
+        # self.input_ln = torch.nn.LayerNorm(
+        #     self.hidden_size,
+        #     eps=self.backbone.config.layer_norm_epsilon
+        # )
         self.contextual_init()
         self._shift_rotary_embedding()
-
-        self.pool_ignore_contextual_tokens = vars(self.config).get("pool_ignore_contextual_tokens", True)
-        self.pool_ignore_instruction_tokens = vars(self.config).get("pool_ignore_instruction_tokens", False)
-
-        tokenizer = transformers.AutoTokenizer.from_pretrained(self.config.embedder)
-        self.pool_instruction_end_id = tokenizer.encode(": ", add_special_tokens=False)[0] # Hardcoded for colon-ending prefixes.
 
     @property
     def num_corpus_tokens(self) -> int:
@@ -848,55 +820,48 @@ class DatasetConditionedBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
         output_hidden_states: bool = False,
         null_dataset_embedding: bool = False,
     ) -> torch.Tensor:
+        # print(f"[DatasetConditionedBiencoder - 0] input_ids.shape => {input_ids.shape} // dataset_embeddings.shape =", dataset_embeddings.shape)
         soft_prompt = self._prepare_dataset_embeddings(
             input_ids=input_ids,
             dataset_embeddings=dataset_embeddings,
             null_dataset_embedding=null_dataset_embedding,
         )
+        # print(f"[DatasetConditionedBiencoder - 1] soft_prompt.shape => {soft_prompt.shape}")
         backbone_attention_mask = torch.ones(
             soft_prompt.shape[0:2],
             dtype=torch.long,
             device=soft_prompt.device,
         )
         inputs_embeds = self.backbone.embeddings(input_ids) # (b, s) -> (b, s, d)
+        # print("[2] inputs_embeds.shape =", inputs_embeds.shape)
         inputs_embeds = torch.cat((soft_prompt, inputs_embeds), dim=1) # (v, 4+b+s, d)
-
+        # print("[3.a] inputs_embeds.shape =", inputs_embeds.shape)
+        attention_mask = torch.cat((backbone_attention_mask, attention_mask), dim=1)
+        # print("[3.b] attention_mask.shape =", attention_mask.shape)
         output = self.backbone(
             inputs_embeds=inputs_embeds,
-            attention_mask=
+            attention_mask=attention_mask,
         ) # (1, 4 + b + s, d)
         # trim soft prompt
         output_vectors = output.last_hidden_state
 
         # use only these tokens
         n_soft_prompt_tokens = soft_prompt.shape[1]
+        # print("n_soft_prompt_tokens =", n_soft_prompt_tokens)
 
-
-
-            # This is a bit arcane but relies on the fact that there will be a BOS token after the
-            # instruction, but also there may or may not be a BOS token at the beginning.
-            instruction_end_idx = (
-                (input_ids == self.pool_instruction_end_id) &
-                attention_mask &
-                (torch.arange(input_ids.shape[1], device=input_ids.device)[None, :] > 0)
-            ).int().argmax(1)
-            is_instruction_token_mask = (
-                torch.arange(input_ids.shape[1], device=input_ids.device)[None, :] <= instruction_end_idx[:, None]
-            )
-            # catch edge case where there is no instruction
-            is_instruction_token_mask = is_instruction_token_mask.where(
-                (instruction_end_idx > 0)[:, None], torch.zeros_like(is_instruction_token_mask)
-            )
-            output_attention_mask = torch.cat((backbone_attention_mask, attention_mask & ~is_instruction_token_mask), dim=1)
-        else:
-            output_attention_mask = input_attention_mask
+        output_vectors = output.last_hidden_state[:, n_soft_prompt_tokens:, :]
+        output_attention_mask = attention_mask[:, n_soft_prompt_tokens:]
 
-
-        output_vectors = output_vectors[:, n_soft_prompt_tokens:, :]
-        output_attention_mask = output_attention_mask[:, n_soft_prompt_tokens:]
+        # print("pooling output_vectors.shape =", output_vectors.shape, "and output_attention_mask.shape =", output_attention_mask.shape)
         output_pooled = mean_pool(output_vectors, output_attention_mask)
+
         # average with original vectors
-
+        # TODO: Argparse for pooling strategy.
+        # output_vectors = torch.cat((soft_prompt_pooled, output_pooled), dim=1) # (b, d) + (b, d) -> (b, 2d)
+        # print("output_pooled.shape =", output_pooled.shape)
+        output = self.output_projection(output_pooled) # (b, 2d) -> (b, d)
+
+        # print("returning output.shape =", output.shape)
 
         if output_hidden_states:
             return {
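The forward pass above prepends the soft-prompt embeddings, extends the attention mask with ones for those positions, then strips the soft-prompt slots before pooling. A self-contained sketch of that prepend / trim / masked-mean pattern; the mean_pool below is the usual masked mean and is only an assumption of what the repo's helper does, and all shapes are placeholders:

import torch

def mean_pool(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Masked mean over the sequence dimension (a common definition; the repo's
    own mean_pool may differ in details)."""
    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)  # (b, s, 1)
    summed = (hidden_states * mask).sum(dim=1)                   # (b, d)
    counts = mask.sum(dim=1).clamp(min=1e-9)                     # (b, 1)
    return summed / counts

# Illustrative shapes: batch of 2, sequence of 5 tokens, 4 soft-prompt slots.
b, s, d, n_soft = 2, 5, 8, 4
inputs_embeds = torch.randn(b, s, d)
attention_mask = torch.ones(b, s, dtype=torch.long)
soft_prompt = torch.randn(b, n_soft, d)

# Prepend soft prompt and extend the mask with ones, as in the forward() above.
inputs_embeds = torch.cat((soft_prompt, inputs_embeds), dim=1)           # (b, n_soft + s, d)
attention_mask = torch.cat(
    (torch.ones(b, n_soft, dtype=torch.long), attention_mask), dim=1
)

# ... a real model would run the backbone on inputs_embeds/attention_mask; here
# we reuse the embeddings directly just to show the trim-then-pool step.
last_hidden_state = inputs_embeds
output_vectors = last_hidden_state[:, n_soft:, :]
output_attention_mask = attention_mask[:, n_soft:]
pooled = mean_pool(output_vectors, output_attention_mask)                # (b, d)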
@@ -967,7 +932,7 @@ class ContextualDocumentEmbeddingTransformer(transformers.PreTrainedModel):
     ):
         super().__init__(config=config)
         dataset_backbone, _ = load_embedder_and_tokenizer(
-            vars(config).get("dataset_backbone"
+            vars(config).get("dataset_backbone") or config.embedder
         )
 
         if config.limit_layers:
@@ -1026,6 +991,8 @@ class ContextualDocumentEmbeddingTransformer(transformers.PreTrainedModel):
             output_hidden_states=output_hidden_states,
         )
 
+
+
 def get_model_class(name: str):
     if name in 'transductive':
         return ContextualDocumentEmbeddingTransformer
@@ -1034,4 +1001,4 @@ def get_model_class(name: str):
     elif name == "dataset_prefix_biencoder":
         return DatasetPrefixBiencoder
     else:
-        raise ValueError(f'unknown model cls {name}')
+        raise ValueError(f'unknown model cls {name}')