---
license: apache-2.0
base_model:
- Qwen/Qwen3-Embedding-8B
pipeline_tag: sentence-similarity
---

This is the SitEmb-v1.5-Qwen3 model trained with additional book notes and their corresponding underlined texts.

### Transformer Usage

The snippet below encodes a query and a context-situated chunk, then scores the pair by mixing the full chunk-plus-context embedding with a residual embedding taken at the chunk's `<|endoftext|>` separator.

```python
import torch
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
from more_itertools import chunked

# Whether to compute the residual (chunk-only) embedding, and its weight when
# mixing its score with the score of the full chunk+context embedding.
residual = True
residual_factor = 0.5

tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen3-Embedding-8B",
    use_fast=True,
    padding_side='left',
)

model = AutoModel.from_pretrained(
    "SituatedEmbedding/SitEmb-v1.5-Qwen3-note",
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)


def _pooling(last_hidden_state, attention_mask, pooling, normalize, input_ids=None, match_idx=None):
    if pooling in ['cls', 'first']:
        reps = last_hidden_state[:, 0]
    elif pooling in ['mean', 'avg', 'average']:
        masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
        reps = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    elif pooling in ['last', 'eos']:
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            reps = last_hidden_state[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_state.shape[0]
            reps = last_hidden_state[torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths]
    elif pooling == 'ext':
        if match_idx is None:
            # default mean
            masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
            reps = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
        else:
            for k in range(input_ids.shape[0]):
                sep_index = input_ids[k].tolist().index(match_idx)
                attention_mask[k][sep_index:] = 0
            masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
            reps = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    else:
        raise ValueError(f'unknown pooling method: {pooling}')
    if normalize:
        reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
    return reps


def first_eos_token_pooling(
    last_hidden_states,
    first_eos_position,
    normalize,
):
    batch_size = last_hidden_states.shape[0]
    reps = last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), first_eos_position]
    if normalize:
        reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
    return reps


def get_detailed_instruct(task_description, query):
    # Query instruction wrapper following the standard Qwen3-Embedding usage format.
    return f'Instruct: {task_description}\nQuery: {query}'


def encode_query(tokenizer, model, pooling, queries, batch_size, normalize, max_length, residual):
    task = "Given a search query, retrieve relevant chunks from fictions that answer the query"
    sents = []
    for query in queries:
        sents.append(get_detailed_instruct(task, query))
    # `residual` is accepted but not applied: queries carry no situating context,
    # so only the main embedding is returned (the second tuple element is None).
    return encode_passage(tokenizer, model, pooling, sents, batch_size, normalize, max_length)


def encode_passage(tokenizer, model, pooling, passages, batch_size, normalize, max_length, residual=False):
    pas_embs = []
    pas_embs_residual = []
    total = len(passages) // batch_size + (1 if len(passages) % batch_size != 0 else 0)
    with tqdm(total=total) as pbar:
        for sent_b in chunked(passages, batch_size):
            batch_dict = tokenizer(sent_b, max_length=max_length, padding=True, truncation=True,
                                   return_tensors='pt').to(model.device)
            if residual:
                # Re-tokenize without tensors to locate, for each passage, the first
                # pad/EOS token (`<|endoftext|>`) after the left padding ends; in the
                # situated-passage format this is the separator that closes the chunk.
                batch_list_dict = tokenizer(sent_b, max_length=max_length, padding=True, truncation=True)
                input_ids = batch_list_dict['input_ids']
                attention_mask = batch_list_dict['attention_mask']
                max_len = len(input_ids[0])
                input_starts = [max_len - sum(att) for att in attention_mask]
                eos_pos = []
                for ii, it in zip(input_ids, input_starts):
                    pos = ii.index(tokenizer.pad_token_id, it)
                    eos_pos.append(pos)
                eos_pos = torch.tensor(eos_pos).to(model.device)
            else:
                eos_pos = None
            outputs = model(**batch_dict)
            pemb_ = _pooling(outputs.last_hidden_state, batch_dict['attention_mask'], pooling, normalize)
            if residual:
                # Residual embedding: the hidden state at the chunk's separator token.
                remb_ = first_eos_token_pooling(outputs.last_hidden_state, eos_pos, normalize)
                pas_embs_residual.append(remb_)
            pas_embs.append(pemb_)
            pbar.update(1)
    pas_embs = torch.cat(pas_embs, dim=0)
    if pas_embs_residual:
        pas_embs_residual = torch.cat(pas_embs_residual, dim=0)
    else:
        pas_embs_residual = None
    return pas_embs, pas_embs_residual


your_query = "Your Query"

query_hidden, _ = encode_query(
    tokenizer, model, pooling="eos", queries=[your_query],
    batch_size=8, normalize=True, max_length=8192, residual=residual,
)

# A situated passage is the chunk itself, an `<|endoftext|>` separator, and then
# an instruction followed by the chunk's surrounding context.
passage_affix = "The context in which the chunk is situated is given below. Please encode the chunk by being aware of the context. Context:\n"
your_chunk = "Your Chunk"
your_context = "Your Context"

candidate_hidden, candidate_hidden_residual = encode_passage(
    tokenizer, model, pooling="eos", passages=[f"{your_chunk}<|endoftext|>{passage_affix}{your_context}"],
    batch_size=4, normalize=True, max_length=8192, residual=residual,
)

query2candidate = query_hidden @ candidate_hidden.T  # [num_queries, num_candidates]
if candidate_hidden_residual is not None:
    query2candidate_residual = query_hidden @ candidate_hidden_residual.T
    if residual_factor == 1.:
        query2candidate = query2candidate_residual
    elif residual_factor == 0.:
        pass
    else:
        query2candidate = query2candidate * (1. - residual_factor) + query2candidate_residual * residual_factor

print(query2candidate.tolist())
```
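
For ranking more than one chunk, the same mixing of base and residual scores applies. Below is a minimal, illustrative sketch that reuses `encode_passage`, `query_hidden`, `passage_affix`, and `residual_factor` from the example above; the chunk/context pairs are placeholders, not part of the model's data.

```python
# Illustrative only: placeholder chunks and contexts, scored and ranked with the
# same base + residual mixing as above.
corpus = [
    ("Chunk A", "Context A"),
    ("Chunk B", "Context B"),
]
passages = [f"{chunk}<|endoftext|>{passage_affix}{context}" for chunk, context in corpus]

cand, cand_residual = encode_passage(
    tokenizer, model, pooling="eos", passages=passages,
    batch_size=4, normalize=True, max_length=8192, residual=residual,
)

scores = query_hidden @ cand.T  # [num_queries, num_chunks]
if cand_residual is not None:
    scores = scores * (1. - residual_factor) + (query_hidden @ cand_residual.T) * residual_factor

# Indices of the best-scoring chunks for each query.
top_scores, top_idx = torch.topk(scores, k=min(2, scores.shape[1]), dim=-1)
print(top_idx.tolist(), top_scores.tolist())
```

In practice the chunk embeddings would be computed once and cached, with only the query encoded at search time.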