Upload ConstBERT

Files changed:
- colbert_configuration.py (+3 −0)
- modeling.py (+12 −9)
colbert_configuration.py
CHANGED
|
@@ -158,6 +158,7 @@ class ResourceSettings:
|
|
| 158 |
collection: str = DefaultVal(None)
|
| 159 |
queries: str = DefaultVal(None)
|
| 160 |
index_name: str = DefaultVal(None)
|
|
|
|
| 161 |
|
| 162 |
|
| 163 |
@dataclass
|
|
@@ -350,6 +351,7 @@ class BaseConfig(CoreConfig):
|
|
| 350 |
|
| 351 |
return config
|
| 352 |
|
|
|
|
| 353 |
try:
|
| 354 |
checkpoint_path = hf_hub_download(
|
| 355 |
repo_id=checkpoint_path, filename="artifact.metadata"
|
|
@@ -360,6 +362,7 @@ class BaseConfig(CoreConfig):
|
|
| 360 |
if os.path.exists(loaded_config_path):
|
| 361 |
loaded_config, _ = cls.from_path(loaded_config_path)
|
| 362 |
loaded_config.set("checkpoint", checkpoint_path)
|
|
|
|
| 363 |
|
| 364 |
return loaded_config
|
| 365 |
|
|
|
|
| 158 |
collection: str = DefaultVal(None)
|
| 159 |
queries: str = DefaultVal(None)
|
| 160 |
index_name: str = DefaultVal(None)
|
| 161 |
+
name_or_path: str = DefaultVal(None)
|
| 162 |
|
| 163 |
|
| 164 |
@dataclass
|
|
|
|
| 351 |
|
| 352 |
return config
|
| 353 |
|
| 354 |
+
name_or_path = checkpoint_path
|
| 355 |
try:
|
| 356 |
checkpoint_path = hf_hub_download(
|
| 357 |
repo_id=checkpoint_path, filename="artifact.metadata"
|
|
|
|
| 362 |
if os.path.exists(loaded_config_path):
|
| 363 |
loaded_config, _ = cls.from_path(loaded_config_path)
|
| 364 |
loaded_config.set("checkpoint", checkpoint_path)
|
| 365 |
+
loaded_config.set("name_or_path", name_or_path)
|
| 366 |
|
| 367 |
return loaded_config
|
| 368 |
|
modeling.py
CHANGED
|
@@ -1,18 +1,11 @@
|
|
| 1 |
import torch.nn as nn
|
| 2 |
from transformers import BertPreTrainedModel, BertModel, AutoTokenizer
|
|
|
|
| 3 |
import torch
|
| 4 |
from tqdm import tqdm
|
| 5 |
-
from transformers import AutoTokenizer
|
| 6 |
from .colbert_configuration import ColBERTConfig
|
| 7 |
from .tokenization_utils import QueryTokenizer, DocTokenizer
|
| 8 |
-
|
| 9 |
-
# this is a hack to force huggingface hub to download the tokenizer files
|
| 10 |
-
try:
|
| 11 |
-
with open("./tokenizer_config.json", "r") as f, open("./tokenizer.json", "r") as f2, open("./vocab.txt", "r") as f3:
|
| 12 |
-
pass
|
| 13 |
-
except Exception as e:
|
| 14 |
-
pass
|
| 15 |
-
|
| 16 |
class NullContextManager(object):
|
| 17 |
def __init__(self, dummy_resource=None):
|
| 18 |
self.dummy_resource = dummy_resource
|
|
@@ -70,6 +63,16 @@ class ConstBERT(BertPreTrainedModel):
|
|
| 70 |
self.doc_project = nn.Linear(colbert_config.doc_maxlen, 32, bias=False)
|
| 71 |
self.query_project = nn.Linear(colbert_config.query_maxlen, 64, bias=False)
|
| 72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
self.query_tokenizer = QueryTokenizer(colbert_config, verbose=verbose)
|
| 74 |
self.doc_tokenizer = DocTokenizer(colbert_config)
|
| 75 |
self.amp_manager = MixedPrecisionManager(True)
|
|
|
|
| 1 |
import torch.nn as nn
|
| 2 |
from transformers import BertPreTrainedModel, BertModel, AutoTokenizer
|
| 3 |
+
from huggingface_hub import hf_hub_download
|
| 4 |
import torch
|
| 5 |
from tqdm import tqdm
|
|
|
|
| 6 |
from .colbert_configuration import ColBERTConfig
|
| 7 |
from .tokenization_utils import QueryTokenizer, DocTokenizer
|
| 8 |
+
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
class NullContextManager(object):
|
| 10 |
def __init__(self, dummy_resource=None):
|
| 11 |
self.dummy_resource = dummy_resource
|
|
|
|
| 63 |
self.doc_project = nn.Linear(colbert_config.doc_maxlen, 32, bias=False)
|
| 64 |
self.query_project = nn.Linear(colbert_config.query_maxlen, 64, bias=False)
|
| 65 |
|
| 66 |
+
## Download required tokenizer files from Hugging Face
|
| 67 |
+
if not os.path.exists(os.path.join(colbert_config.name_or_path, "tokenizer.json")):
|
| 68 |
+
hf_hub_download(repo_id=colbert_config.name_or_path, filename="tokenizer.json")
|
| 69 |
+
if not os.path.exists(os.path.join(colbert_config.name_or_path, "vocab.txt")):
|
| 70 |
+
hf_hub_download(repo_id=colbert_config.name_or_path, filename="vocab.txt")
|
| 71 |
+
if not os.path.exists(os.path.join(colbert_config.name_or_path, "tokenizer_config.json")):
|
| 72 |
+
hf_hub_download(repo_id=colbert_config.name_or_path, filename="tokenizer_config.json")
|
| 73 |
+
if not os.path.exists(os.path.join(colbert_config.name_or_path, "special_tokens_map.json")):
|
| 74 |
+
hf_hub_download(repo_id=colbert_config.name_or_path, filename="special_tokens_map.json")
|
| 75 |
+
|
| 76 |
self.query_tokenizer = QueryTokenizer(colbert_config, verbose=verbose)
|
| 77 |
self.doc_tokenizer = DocTokenizer(colbert_config)
|
| 78 |
self.amp_manager = MixedPrecisionManager(True)
|