Update README.md
Browse files
README.md
CHANGED
|
@@ -1,3 +1,85 @@
|
|
| 1 |
---
|
| 2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
language:
|
| 3 |
+
- ru
|
| 4 |
+
tags:
|
| 5 |
+
- RAG
|
| 6 |
+
- cross-encoder
|
| 7 |
+
pipeline_tag: sentence-similarity
|
| 8 |
---
|
| 9 |
+
# Overview
|
| 10 |
+
Cross-encoder for russian language. Primarily trained for **RAG** purposes.
|
| 11 |
+
Take two strings, assess if they are related (question and answer pair).
|
| 12 |
+
|
| 13 |
+
# Usage
|
| 14 |
+
|
| 15 |
+
```python
|
| 16 |
+
import torch
|
| 17 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 18 |
+
|
| 19 |
+
!wget https://huggingface.co/GrigoryT22/cross-encoder-ru/resolve/main/model.pt # or simply load the file via browser
|
| 20 |
+
|
| 21 |
+
model = Model() # copy-past class code (see below) and run it
|
| 22 |
+
model.load_state_dict(torch.load('./model.pt'), strict=False) # path to downloaded file with the model
|
| 23 |
+
# missing_keys=['labse.embeddings.position_ids'] - this is [OK](https://github.com/huggingface/transformers/issues/16353)
|
| 24 |
+
|
| 25 |
+
string_1 = """
|
| 26 |
+
Компания судится с артистом
|
| 27 |
+
""".strip()
|
| 28 |
+
|
| 29 |
+
string_2 = """
|
| 30 |
+
По заявлению инвесторов, компания знала о рисках заключения подобного контракта задолго до антисемитских высказываний Уэста,
|
| 31 |
+
которые он озвучил в октябре 2022 года. Однако, несмотря на то, что Adidas прекратил сотрудничество с артистом,
|
| 32 |
+
избежать судебного разбирательства не удалось. После расторжения контракта с рэпером компания потеряет 1,3 миллиарда долларов.
|
| 33 |
+
""".strip()
|
| 34 |
+
|
| 35 |
+
model([
|
| 36 |
+
[string_1, string_2]
|
| 37 |
+
])
|
| 38 |
+
# should be something like this --->>> tensor([[-4.0403, 3.8442]], grad_fn=<AddmmBackward0>)
|
| 39 |
+
# model is pretty sure that these two strings are related, second number is bigger (logits for binary classifications, batch size one in this case)
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# Model class
|
| 44 |
+
```python
|
| 45 |
+
class Model(nn.Module):
|
| 46 |
+
"""
|
| 47 |
+
labse - base bert-like model
|
| 48 |
+
from labse I use pooler layer as input
|
| 49 |
+
then classification head - binary classification to predict if this pair is TRUE question-answer
|
| 50 |
+
"""
|
| 51 |
+
def __init__(self):
|
| 52 |
+
super().__init__()
|
| 53 |
+
self.labse_config = AutoConfig.from_pretrained('cointegrated/LaBSE-en-ru')
|
| 54 |
+
self.labse = AutoModel.from_config(self.labse_config)
|
| 55 |
+
self.tokenizer = AutoTokenizer.from_pretrained('cointegrated/LaBSE-en-ru')
|
| 56 |
+
self.cls = nn.Sequential(OrderedDict(
|
| 57 |
+
[
|
| 58 |
+
('dropout_in', torch.nn.Dropout(.0)),
|
| 59 |
+
('layernorm_in' , nn.LayerNorm(768, eps=1e-05)),
|
| 60 |
+
|
| 61 |
+
('fc_1' , nn.Linear(768, 768 * 2)),
|
| 62 |
+
('act_1' , nn.GELU()),
|
| 63 |
+
('layernorm_1' , nn.LayerNorm(768 * 2, eps=1e-05)),
|
| 64 |
+
|
| 65 |
+
('fc_2' , nn.Linear(768 * 2, 768 * 2)),
|
| 66 |
+
('act_2' , nn.GELU()),
|
| 67 |
+
('layernorm_2' , nn.LayerNorm(768 * 2, eps=1e-05)),
|
| 68 |
+
|
| 69 |
+
('fc_3' , nn.Linear(768 * 2, 768)),
|
| 70 |
+
('act_3' , nn.GELU()),
|
| 71 |
+
('layernorm_3' , nn.LayerNorm(768, eps=1e-05)),
|
| 72 |
+
|
| 73 |
+
('fc_4' , nn.Linear(768, 256)),
|
| 74 |
+
('act_4' , nn.GELU()),
|
| 75 |
+
('layernorm_4' , nn.LayerNorm(256, eps=1e-05)),
|
| 76 |
+
|
| 77 |
+
('fc_5' , nn.Linear(256, 2, bias=True)),
|
| 78 |
+
]
|
| 79 |
+
))
|
| 80 |
+
def forward(self, text):
|
| 81 |
+
token = self.tokenizer(text, padding=True, truncation=True, return_tensors='pt').to(device)
|
| 82 |
+
model_output = self.labse(**token)
|
| 83 |
+
result = self.cls(model_output.pooler_output)
|
| 84 |
+
return result
|
| 85 |
+
```
|