DeepMostInnovations committed on
Commit 19da03f · verified · 1 Parent(s): ae91f30

Add inference script

Files changed (1)
  1. hindi_embeddings.py +724 -0
hindi_embeddings.py ADDED
@@ -0,0 +1,724 @@
import os
import torch
import json
import numpy as np
from torch import nn
from torch.nn import functional as F
import sentencepiece as spm
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE


# Tokenizer wrapper class
class SentencePieceTokenizerWrapper:
    def __init__(self, sp_model_path):
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(sp_model_path)
        self.vocab_size = self.sp_model.GetPieceSize()

        # Special token IDs from tokenizer training
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.unk_token_id = 3

        # Set special tokens
        self.pad_token = "<pad>"
        self.bos_token = "<s>"
        self.eos_token = "</s>"
        self.unk_token = "<unk>"
        self.mask_token = "<mask>"

    def __call__(self, text, padding=False, truncation=False, max_length=None, return_tensors=None):
        # Handle both string and list inputs
        if isinstance(text, str):
            # Encode a single string
            ids = self.sp_model.EncodeAsIds(text)

            # Handle truncation
            if truncation and max_length and len(ids) > max_length:
                ids = ids[:max_length]

            attention_mask = [1] * len(ids)

            # Handle padding
            if padding and max_length:
                padding_length = max(0, max_length - len(ids))
                ids = ids + [self.pad_token_id] * padding_length
                attention_mask = attention_mask + [0] * padding_length

            result = {
                'input_ids': ids,
                'attention_mask': attention_mask
            }

            # Convert to tensors if requested
            if return_tensors == 'pt':
                result = {k: torch.tensor([v]) for k, v in result.items()}

            return result

        # Process a batch of texts
        batch_encoded = [self.sp_model.EncodeAsIds(t) for t in text]

        # Apply truncation if needed
        if truncation and max_length:
            batch_encoded = [ids[:max_length] for ids in batch_encoded]

        # Create attention masks
        batch_attention_mask = [[1] * len(ids) for ids in batch_encoded]

        # Apply padding if needed
        if padding:
            if max_length:
                max_len = max_length
            else:
                max_len = max(len(ids) for ids in batch_encoded)

            # Pad sequences to max_len
            batch_encoded = [ids + [self.pad_token_id] * (max_len - len(ids)) for ids in batch_encoded]
            batch_attention_mask = [mask + [0] * (max_len - len(mask)) for mask in batch_attention_mask]

        result = {
            'input_ids': batch_encoded,
            'attention_mask': batch_attention_mask
        }

        # Convert to tensors if requested
        if return_tensors == 'pt':
            result = {k: torch.tensor(v) for k, v in result.items()}

        return result

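# Illustrative usage of the wrapper above (a sketch, not part of the shipped
# checkpoint; "tokenizer.model" is a placeholder path to a trained
# SentencePiece model). It mirrors the call made later in HindiEmbedder.encode:
#
#     tokenizer = SentencePieceTokenizerWrapper("tokenizer.model")
#     batch = tokenizer(["नमस्ते दुनिया"], padding=True, truncation=True,
#                       max_length=16, return_tensors="pt")
#     # batch["input_ids"] and batch["attention_mask"] are (1, 16) tensors,
#     # padded with pad_token_id / 0 up to max_length.
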
# Model architecture components
class MultiHeadAttention(nn.Module):
    """Multi-headed attention mechanism"""
    def __init__(self, config):
        super().__init__()
        self.num_attention_heads = config["num_attention_heads"]
        self.attention_head_size = config["hidden_size"] // config["num_attention_heads"]
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Query, Key, Value projections
        self.query = nn.Linear(config["hidden_size"], self.all_head_size)
        self.key = nn.Linear(config["hidden_size"], self.all_head_size)
        self.value = nn.Linear(config["hidden_size"], self.all_head_size)

        # Output projection
        self.output = nn.Sequential(
            nn.Linear(self.all_head_size, config["hidden_size"]),
            nn.Dropout(config["attention_probs_dropout_prob"])
        )

        # Simplified relative position bias
        self.max_position_embeddings = config["max_position_embeddings"]
        self.relative_attention_bias = nn.Embedding(
            2 * config["max_position_embeddings"] - 1,
            config["num_attention_heads"]
        )

    def transpose_for_scores(self, x):
        new_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None):
        batch_size, seq_length = hidden_states.size()[:2]

        # Project inputs to queries, keys, and values
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Take the dot product between query and key to get the raw attention scores
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        # Generate relative position matrix
        position_ids = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device)
        relative_position = position_ids.unsqueeze(1) - position_ids.unsqueeze(0)  # [seq_len, seq_len]
        # Shift values to be >= 0
        relative_position = relative_position + self.max_position_embeddings - 1
        # Ensure indices are within bounds
        relative_position = torch.clamp(relative_position, 0, 2 * self.max_position_embeddings - 2)

        # Get relative position embeddings [seq_len, seq_len, num_heads]
        rel_attn_bias = self.relative_attention_bias(relative_position)

        # Reshape to add to attention heads [1, num_heads, seq_len, seq_len]
        rel_attn_bias = rel_attn_bias.permute(2, 0, 1).unsqueeze(0)

        # Add to attention scores - now dimensions will match
        attention_scores = attention_scores + rel_attn_bias

        # Scale attention scores
        attention_scores = attention_scores / (self.attention_head_size ** 0.5)

        # Apply attention mask
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities
        attention_probs = F.softmax(attention_scores, dim=-1)

        # Apply dropout
        attention_probs = F.dropout(attention_probs, p=0.1, training=self.training)

        # Apply attention to values
        context_layer = torch.matmul(attention_probs, value_layer)

        # Reshape back to [batch_size, seq_length, hidden_size]
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_shape)

        # Final output projection
        output = self.output(context_layer)

        return output

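# Worked example of the relative-position indexing above (illustrative only,
# using a toy max_position_embeddings of 4, so the bias table has 2*4 - 1 = 7
# rows). For seq_length = 3, position_ids = [0, 1, 2] and
#     relative_position[i][j] = i - j      -> [[0, -1, -2], [1, 0, -1], [2, 1, 0]]
# after the +3 shift and clamp to [0, 6]   -> [[3,  2,  1], [4, 3,  2], [5, 4, 3]]
# Each index selects a learned per-head bias that is added to the raw
# attention scores before scaling and softmax.
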
class EnhancedTransformerLayer(nn.Module):
    """Advanced transformer layer with pre-layer norm and enhanced attention"""
    def __init__(self, config):
        super().__init__()
        self.attention_pre_norm = nn.LayerNorm(config["hidden_size"], eps=config["layer_norm_eps"])
        self.attention = MultiHeadAttention(config)

        self.ffn_pre_norm = nn.LayerNorm(config["hidden_size"], eps=config["layer_norm_eps"])

        # Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(config["hidden_size"], config["intermediate_size"]),
            nn.GELU(),
            nn.Dropout(config["hidden_dropout_prob"]),
            nn.Linear(config["intermediate_size"], config["hidden_size"]),
            nn.Dropout(config["hidden_dropout_prob"])
        )

    def forward(self, hidden_states, attention_mask=None):
        # Pre-layer norm for attention
        attn_norm_hidden = self.attention_pre_norm(hidden_states)

        # Self-attention
        attention_output = self.attention(attn_norm_hidden, attention_mask)

        # Residual connection
        hidden_states = hidden_states + attention_output

        # Pre-layer norm for feed-forward
        ffn_norm_hidden = self.ffn_pre_norm(hidden_states)

        # Feed-forward
        ffn_output = self.ffn(ffn_norm_hidden)

        # Residual connection
        hidden_states = hidden_states + ffn_output

        return hidden_states


class AdvancedTransformerModel(nn.Module):
    """Advanced Transformer model for inference"""

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Embeddings
        self.word_embeddings = nn.Embedding(
            config["vocab_size"],
            config["hidden_size"],
            padding_idx=config["pad_token_id"]
        )

        # Position embeddings
        self.position_embeddings = nn.Embedding(config["max_position_embeddings"], config["hidden_size"])

        # Embedding dropout
        self.embedding_dropout = nn.Dropout(config["hidden_dropout_prob"])

        # Transformer layers
        self.layers = nn.ModuleList([
            EnhancedTransformerLayer(config) for _ in range(config["num_hidden_layers"])
        ])

        # Final layer norm
        self.final_layer_norm = nn.LayerNorm(config["hidden_size"], eps=config["layer_norm_eps"])

    def forward(self, input_ids, attention_mask=None):
        input_shape = input_ids.size()
        batch_size, seq_length = input_shape

        # Get position ids
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        # Get embeddings
        word_embeds = self.word_embeddings(input_ids)
        position_embeds = self.position_embeddings(position_ids)

        # Sum embeddings
        embeddings = word_embeds + position_embeds

        # Apply dropout
        embeddings = self.embedding_dropout(embeddings)

        # Default attention mask
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=input_ids.device)

        # Extended attention mask for transformer layers (1 for tokens to attend to, 0 for masked tokens)
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Apply transformer layers
        hidden_states = embeddings
        for layer in self.layers:
            hidden_states = layer(hidden_states, extended_attention_mask)

        # Final layer norm
        hidden_states = self.final_layer_norm(hidden_states)

        return hidden_states

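# Note on the additive attention mask built above (a small illustration, not
# extra model logic): a padding mask [1, 1, 1, 0] becomes
#     (1.0 - [[[[1, 1, 1, 0]]]]) * -10000.0 == [[[[0.0, 0.0, 0.0, -10000.0]]]]
# so masked positions receive a large negative score and effectively vanish
# after the softmax inside MultiHeadAttention.
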
class AdvancedPooling(nn.Module):
    """Advanced pooling module supporting multiple pooling strategies"""
    def __init__(self, config):
        super().__init__()
        self.pooling_mode = config["pooling_mode"]  # 'mean', 'max', 'cls', 'attention'
        self.hidden_size = config["hidden_size"]

        # For attention pooling
        if self.pooling_mode == 'attention':
            self.attention_weights = nn.Linear(config["hidden_size"], 1)

        # For weighted pooling
        elif self.pooling_mode == 'weighted':
            self.weight_layer = nn.Linear(config["hidden_size"], 1)

    def forward(self, token_embeddings, attention_mask=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(token_embeddings[:, :, 0])

        mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()

        if self.pooling_mode == 'cls':
            # Use [CLS] token (first token)
            pooled = token_embeddings[:, 0]

        elif self.pooling_mode == 'max':
            # Max pooling
            token_embeddings = token_embeddings.clone()
            # Set padding tokens to large negative value to exclude them from max
            token_embeddings[mask_expanded == 0] = -1e9
            pooled = torch.max(token_embeddings, dim=1)[0]

        elif self.pooling_mode == 'attention':
            # Attention pooling
            weights = self.attention_weights(token_embeddings).squeeze(-1)
            # Mask out padding tokens
            weights = weights.masked_fill(attention_mask == 0, -1e9)
            weights = F.softmax(weights, dim=1).unsqueeze(-1)
            pooled = torch.sum(token_embeddings * weights, dim=1)

        elif self.pooling_mode == 'weighted':
            # Weighted average pooling
            weights = torch.sigmoid(self.weight_layer(token_embeddings)).squeeze(-1)
            # Apply mask
            weights = weights * attention_mask
            # Normalize weights
            sum_weights = torch.sum(weights, dim=1, keepdim=True)
            sum_weights = torch.clamp(sum_weights, min=1e-9)
            weights = weights / sum_weights
            # Apply weights
            pooled = torch.sum(token_embeddings * weights.unsqueeze(-1), dim=1)

        else:  # Default to mean pooling
            # Mean pooling
            sum_embeddings = torch.sum(token_embeddings * mask_expanded, dim=1)
            sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
            pooled = sum_embeddings / sum_mask

        # L2 normalize
        pooled = F.normalize(pooled, p=2, dim=1)

        return pooled

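# The default mean-pooling branch above computes, per sentence,
#     pooled = sum_t(mask_t * h_t) / max(sum_t(mask_t), 1e-9)
# and every strategy finishes with L2 normalization, so downstream cosine
# similarity reduces to a plain dot product of the returned vectors.
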
class SentenceEmbeddingModel(nn.Module):
    """Complete sentence embedding model for inference"""
    def __init__(self, config):
        super(SentenceEmbeddingModel, self).__init__()
        self.config = config

        # Create transformer model
        self.transformer = AdvancedTransformerModel(config)

        # Create pooling module
        self.pooling = AdvancedPooling(config)

        # Build projection module if needed
        if "projection_dim" in config and config["projection_dim"] > 0:
            self.use_projection = True
            self.projection = nn.Sequential(
                nn.Linear(config["hidden_size"], config["hidden_size"]),
                nn.GELU(),
                nn.Linear(config["hidden_size"], config["projection_dim"]),
                nn.LayerNorm(config["projection_dim"], eps=config["layer_norm_eps"])
            )
        else:
            self.use_projection = False

    def forward(self, input_ids, attention_mask=None):
        # Get token embeddings from transformer
        token_embeddings = self.transformer(input_ids, attention_mask)

        # Pool token embeddings
        pooled_output = self.pooling(token_embeddings, attention_mask)

        # Apply projection if enabled
        if self.use_projection:
            pooled_output = self.projection(pooled_output)
            pooled_output = F.normalize(pooled_output, p=2, dim=1)

        return pooled_output

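# Sketch of the config keys this model expects. The values below are
# illustrative placeholders, not the shipped config.json (which HindiEmbedder
# loads from the model directory):
#
#     config = {
#         "vocab_size": 32000, "hidden_size": 768, "num_hidden_layers": 12,
#         "num_attention_heads": 12, "intermediate_size": 3072,
#         "max_position_embeddings": 512, "hidden_dropout_prob": 0.1,
#         "attention_probs_dropout_prob": 0.1, "layer_norm_eps": 1e-12,
#         "pad_token_id": 0, "pooling_mode": "mean", "projection_dim": 0,
#     }
#     model = SentenceEmbeddingModel(config)  # randomly initialized weights;
#     # projection_dim <= 0 leaves the projection head disabled.
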
class HindiEmbedder:
    def __init__(self, model_path="/home/ubuntu/output/hindi-embeddings-custom-tokenizer/final", tokenizer_path=None):
        """
        Initialize the Hindi sentence embedder.

        Args:
            model_path: Path to the model directory
            tokenizer_path: Optional path to tokenizer. If None, will look in the model directory.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Load tokenizer
        if tokenizer_path is None:
            # Try standard location in model directory
            tokenizer_path = os.path.join(model_path, "tokenizer.model")
            if not os.path.exists(tokenizer_path):
                # Try original location
                tokenizer_path = "/home/ubuntu/hindi_tokenizer/tokenizer.model"

        if not os.path.exists(tokenizer_path):
            raise FileNotFoundError(f"Could not find tokenizer at {tokenizer_path}")

        self.tokenizer = SentencePieceTokenizerWrapper(tokenizer_path)
        print(f"Loaded tokenizer from {tokenizer_path} with vocabulary size: {self.tokenizer.vocab_size}")

        # Load model config
        config_path = os.path.join(model_path, "config.json")
        with open(config_path, "r") as f:
            self.config = json.load(f)
        print(f"Loaded model config with hidden_size={self.config['hidden_size']}")

        # Load model
        model_pt_path = os.path.join(model_path, "embedding_model.pt")

        try:
            # Support both PyTorch 2.6+ and older versions
            try:
                checkpoint = torch.load(model_pt_path, map_location=self.device, weights_only=False)
                print("Loaded model using PyTorch 2.6+ style loading")
            except TypeError:
                checkpoint = torch.load(model_pt_path, map_location=self.device)
                print("Loaded model using older PyTorch style loading")

            # Create model
            self.model = SentenceEmbeddingModel(self.config)

            # Load state dict
            if "model_state_dict" in checkpoint:
                state_dict = checkpoint["model_state_dict"]
            else:
                state_dict = checkpoint

            missing_keys, unexpected_keys = self.model.load_state_dict(state_dict, strict=False)
            print(f"Loaded model with {len(missing_keys)} missing keys and {len(unexpected_keys)} unexpected keys")

            # Move to device
            self.model.to(self.device)
            self.model.eval()
            print("Model loaded successfully and placed in evaluation mode")

        except Exception as e:
            print(f"Error loading model: {e}")
            raise RuntimeError(f"Failed to load the model: {e}")

    def encode(self, sentences, batch_size=32, normalize=True):
        """
        Encode sentences to embeddings.

        Args:
            sentences: A string or list of strings to encode
            batch_size: Batch size for encoding
            normalize: Whether to normalize the embeddings

        Returns:
            Numpy array of embeddings
        """
        # Handle single sentence
        if isinstance(sentences, str):
            sentences = [sentences]

        all_embeddings = []

        # Process in batches
        with torch.no_grad():
            for i in range(0, len(sentences), batch_size):
                batch = sentences[i:i+batch_size]

                # Tokenize
                inputs = self.tokenizer(
                    batch,
                    padding=True,
                    truncation=True,
                    max_length=self.config.get("max_position_embeddings", 128),
                    return_tensors="pt"
                )

                # Move to device
                input_ids = inputs["input_ids"].to(self.device)
                attention_mask = inputs["attention_mask"].to(self.device)

                # Get embeddings
                embeddings = self.model(input_ids, attention_mask)

                # Move to CPU and convert to numpy
                all_embeddings.append(embeddings.cpu().numpy())

        # Concatenate all embeddings
        all_embeddings = np.vstack(all_embeddings)

        # Normalize if requested
        if normalize:
            all_embeddings = all_embeddings / np.linalg.norm(all_embeddings, axis=1, keepdims=True)

        return all_embeddings

    def compute_similarity(self, texts1, texts2=None):
        """
        Compute cosine similarity between texts.

        Args:
            texts1: First set of texts
            texts2: Second set of texts. If None, compute similarity matrix within texts1.

        Returns:
            Similarity scores
        """
        embeddings1 = self.encode(texts1)

        if texts2 is None:
            # Compute similarity matrix within texts1
            similarities = cosine_similarity(embeddings1)
            return similarities
        else:
            # Compute similarity between texts1 and texts2
            embeddings2 = self.encode(texts2)

            if len(texts1) == len(texts2):
                # Compute pairwise similarity when the number of texts match
                return np.array([
                    cosine_similarity([e1], [e2])[0][0]
                    for e1, e2 in zip(embeddings1, embeddings2)
                ])
            else:
                # Return full similarity matrix
                return cosine_similarity(embeddings1, embeddings2)

    def search(self, query, documents, top_k=5):
        """
        Search for similar documents to a query.

        Args:
            query: The query text
            documents: List of documents to search
            top_k: Number of top results to return

        Returns:
            List of dictionaries with document and score
        """
        # Get embeddings
        query_embedding = self.encode([query])[0]
        document_embeddings = self.encode(documents)

        # Compute similarities
        similarities = np.dot(document_embeddings, query_embedding)

        # Get top indices
        top_indices = np.argsort(similarities)[-top_k:][::-1]

        # Return results
        results = []
        for idx in top_indices:
            results.append({
                "document": documents[idx],
                "score": float(similarities[idx])
            })

        return results

    def evaluate_similarity_samples(self):
        """Evaluate model on some standard similarity examples for Hindi"""
        test_pairs = [
            (
                "मुझे हिंदी में पढ़ना बहुत पसंद है।",
                "मैं हिंदी किताबें बहुत पसंद करता हूँ।"
            ),
            (
                "आज मौसम बहुत अच्छा है।",
                "आज बारिश हो रही है।"
            ),
            (
                "भारत एक विशाल देश है।",
                "भारत में कई भाषाएँ बोली जाती हैं।"
            ),
            (
                "कंप्यूटर विज्ञान एक रोचक विषय है।",
                "मैं कंप्यूटर साइंस का छात्र हूँ।"
            ),
            (
                "मैं रोज सुबह योग करता हूँ।",
                "स्वस्थ रहने के लिए व्यायाम जरूरी है।"
            ),
            # Add contrasting pairs to test discrimination
            (
                "मुझे हिंदी में पढ़ना बहुत पसंद है।",
                "क्रिकेट भारत में सबसे लोकप्रिय खेल है।"
            ),
            (
                "आज मौसम बहुत अच्छा है।",
                "भारतीय व्यंजन दुनिया भर में मशहूर हैं।"
            ),
            (
                "कंप्यूटर विज्ञान एक रोचक विषय है।",
                "हिमालय दुनिया का सबसे ऊंचा पर्वत है।"
            )
        ]

        print("Evaluating model on standard similarity samples:")
        for i, (text1, text2) in enumerate(test_pairs):
            similarity = self.compute_similarity([text1], [text2])[0]
            print(f"\nPair {i+1}:")
            print(f"  Sentence 1: {text1}")
            print(f"  Sentence 2: {text2}")
            print(f"  Similarity: {similarity:.4f}")

        return

    def visualize_embeddings(self, sentences, labels=None, output_path="hindi_embeddings_visualization.png"):
        """
        Create a t-SNE visualization of the embeddings.

        Args:
            sentences: List of sentences to visualize
            labels: Optional list of labels for the points
            output_path: Path to save the visualization

        Returns:
            Path to the saved visualization
        """
        # Encode sentences
        embeddings = self.encode(sentences)

        # Apply t-SNE
        tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(embeddings)-1))
        reduced_embeddings = tsne.fit_transform(embeddings)

        # Create plot
        plt.figure(figsize=(12, 10))

        # Plot points
        scatter = plt.scatter(
            reduced_embeddings[:, 0],
            reduced_embeddings[:, 1],
            c=range(len(reduced_embeddings)),
            cmap='viridis',
            alpha=0.8,
            s=100
        )

        # Add labels if provided
        if labels:
            for i, label in enumerate(labels):
                plt.annotate(
                    label,
                    (reduced_embeddings[i, 0], reduced_embeddings[i, 1]),
                    fontsize=10,
                    alpha=0.7
                )

        plt.title("t-SNE Visualization of Hindi Sentence Embeddings", fontsize=16)
        plt.xlabel("Dimension 1", fontsize=12)
        plt.ylabel("Dimension 2", fontsize=12)
        plt.colorbar(scatter, label="Sentence Index")
        plt.grid(alpha=0.3)

        # Save the figure
        plt.tight_layout()
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()

        print(f"Visualization saved to {output_path}")
        return output_path

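# Minimal library-style usage sketch (paths are deployment-specific and must
# point at a directory containing config.json, embedding_model.pt and the
# tokenizer model; "/path/to/model_dir" is a placeholder):
#
#     embedder = HindiEmbedder("/path/to/model_dir")
#     vectors = embedder.encode(["भारत एक विशाल देश है।"])   # (1, embedding_dim) array
#     scores = embedder.compute_similarity(
#         ["आज मौसम बहुत अच्छा है।"], ["आज बारिश हो रही है।"]
#     )   # pairwise cosine scores, since both lists have equal length
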
def main():
    # Create embedder
    embedder = HindiEmbedder()

    # Run sample evaluation
    embedder.evaluate_similarity_samples()

    # Example of semantic search
    print("\nSemantic Search Example:")
    query = "भारत की संस्कृति"
    documents = [
        "भारतीय संस्कृति दुनिया की सबसे प्राचीन संस्कृतियों में से एक है।",
        "भारत की आबादी 1.3 अरब से अधिक है।",
        "हिमालय पर्वत श्रृंखला भारत के उत्तर में स्थित है।",
        "भारतीय व्यंजन में मसालों का प्रयोग किया जाता है।",
        "भारत में 22 आधिकारिक भाषाएँ हैं।",
        "संस्कृति लोगों के रहन-सहन का तरीका है।",
        "भारत के विभिन्न राज्यों की अपनी अलग संस्कृति है।",
        "रामायण और महाभारत भारतीय संस्कृति के महत्वपूर्ण हिस्से हैं।",
    ]

    results = embedder.search(query, documents)

    print(f"Query: {query}")
    print("Top results:")
    for i, result in enumerate(results):
        print(f"{i+1}. Score: {result['score']:.4f}")
        print(f"   {result['document']}")

    # Create visualization example
    print("\nCreating embedding visualization...")
    visualization_sentences = [
        "मुझे हिंदी में पढ़ना बहुत पसंद है।",
        "मैं हिंदी किताबें बहुत पसंद करता हूँ।",
        "आज मौसम बहुत अच्छा है।",
        "आज बारिश हो रही है।",
        "भारत एक विशाल देश है।",
        "भारत में कई भाषाएँ बोली जाती हैं।",
        "कंप्यूटर विज्ञान एक रोचक विषय है।",
        "मैं कंप्यूटर साइंस का छात्र हूँ।",
        "क्रिकेट भारत में सबसे लोकप्रिय खेल है।",
        "भारतीय व्यंजन दुनिया भर में मशहूर हैं।",
        "हिमालय दुनिया का सबसे ऊंचा पर्वत है।",
        "गंगा भारत की सबसे पवित्र नदी है।",
        "दिल्ली भारत की राजधानी है।",
        "मुंबई भारत का आर्थिक केंद्र है।",
        "तमिल, तेलुगु, कन्नड़ और मलयालम दक्षिण भारत की प्रमुख भाषाएँ हैं।"
    ]

    labels = ["पढ़ना", "किताबें", "मौसम", "बारिश", "भारत", "भाषाएँ", "कंप्यूटर",
              "छात्र", "क्रिकेट", "व्यंजन", "हिमालय", "गंगा", "दिल्ली", "मुंबई", "भाषाएँ"]

    embedder.visualize_embeddings(visualization_sentences, labels)


if __name__ == "__main__":
    main()
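# The demo above can be run directly; note that the default model and
# tokenizer paths in HindiEmbedder.__init__ and main() are machine-specific
# and will likely need editing for other environments:
#     python hindi_embeddings.py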