Update README.md
Browse files
README.md
CHANGED
|
@@ -63,7 +63,7 @@ from transformers import AutoTokenizer, AutoModel
|
|
| 63 |
tokenizer = AutoTokenizer.from_pretrained("nomic-ai/nomic-embed-code")
|
| 64 |
model = AutoModel.from_pretrained("nomic-ai/nomic-embed-code")
|
| 65 |
|
| 66 |
- def last_token_pooling(
|
| 67 |
sequence_lengths = attention_mask.sum(-1) - 1
|
| 68 |
return hidden_states[torch.arange(hidden_states.shape[0]), sequence_lengths]
|
| 69 |
|
|
@@ -74,7 +74,8 @@ code_snippets = queries + codes
|
|
| 74 |
encoded_input = tokenizer(code_snippets, padding=True, truncation=True, return_tensors='pt')
|
| 75 |
model.eval()
|
| 76 |
with torch.no_grad():
|
| 77 |
- model_output = model(**encoded_input)
|
|
|
|
| 78 |
embeddings = last_token_pooling(model_output, encoded_input['attention_mask'])
|
| 79 |
embeddings = F.normalize(embeddings, p=2, dim=1)
|
| 80 |
print(embeddings.shape)
|
|
@@ -95,7 +96,7 @@ model = SentenceTransformer("nomic-ai/nomic-embed-code")
|
|
| 95 |
query_emb = model.encode(queries, prompt_name="query")
|
| 96 |
code_emb = model.encode(code_snippets)
|
| 97 |
|
| 98 |
- similarity = model.similarity(query_emb, code_emb)
|
| 99 |
print(similarity)
|
| 100 |
```
|
| 101 |
|
|
|
|
| 63 |
tokenizer = AutoTokenizer.from_pretrained("nomic-ai/nomic-embed-code")
|
| 64 |
model = AutoModel.from_pretrained("nomic-ai/nomic-embed-code")
|
| 65 |
|
| 66 |
+ def last_token_pooling(hidden_states, attention_mask):
|
| 67 |
sequence_lengths = attention_mask.sum(-1) - 1
|
| 68 |
return hidden_states[torch.arange(hidden_states.shape[0]), sequence_lengths]
|
| 69 |
|
|
|
|
| 74 |
encoded_input = tokenizer(code_snippets, padding=True, truncation=True, return_tensors='pt')
|
| 75 |
model.eval()
|
| 76 |
with torch.no_grad():
|
| 77 |
+ model_output = model(**encoded_input)[0]
|
| 78 |
+
|
| 79 |
embeddings = last_token_pooling(model_output, encoded_input['attention_mask'])
|
| 80 |
embeddings = F.normalize(embeddings, p=2, dim=1)
|
| 81 |
print(embeddings.shape)
|
|
|
|
| 96 |
query_emb = model.encode(queries, prompt_name="query")
|
| 97 |
code_emb = model.encode(code_snippets)
|
| 98 |
|
| 99 |
+ similarity = model.similarity(query_emb[0], code_emb[0])
|
| 100 |
print(similarity)
|
| 101 |
```
|
| 102 |
|