import torch


class GaussianCoveragePooling(torch.nn.Module):
    def __init__(self, coverage_chunks, sigma, alpha):
        """
        Custom pooling layer that computes weighted mean pooling using Gaussian-based weights.

        Args:
            coverage_chunks (int): Number of weighted pooling operations (N).
            sigma (float): Standard deviation for Gaussian weighting, expressed as a
                fraction of the sequence length.
            alpha (float): Weighting factor for merging with standard mean pooling.
        """
        super().__init__()
        self.coverage_chunks = coverage_chunks
        self.sigma = sigma
        self.alpha = alpha

    def forward(self, features, chunk_indicators=None):
        """
        Computes weighted mean pooling using Gaussian-based weights.

        Args:
            features (dict): The token embeddings and attention mask.
            chunk_indicators (tensor[bz, 1], optional): Indices selecting a single chunk
                embedding per example; mainly useful for training. Leave as None (the
                default) to return the embeddings for all chunks, e.g. at inference time.
        """
        token_embeddings = features["token_embeddings"]
        attention_mask = features["attention_mask"].float().unsqueeze(-1)

        batch_size, seq_len, hidden_dim = token_embeddings.shape
        device = token_embeddings.device

        # Number of non-padding tokens per example, shape [batch_size, 1].
        seq_lengths = attention_mask.squeeze(-1).sum(dim=1, keepdim=True)

        # Standard masked mean pooling over all non-padding tokens.
        sum_embeddings = torch.sum(token_embeddings * attention_mask, dim=1)
        sum_mask = torch.sum(attention_mask, dim=1).clamp(min=1e-9)
        standard_mean = sum_embeddings / sum_mask

        # Evenly spaced chunk centers in (0, 1), endpoints excluded, scaled by each
        # example's true length: shape [batch_size, coverage_chunks].
        chunk_positions = torch.linspace(0, 1, self.coverage_chunks + 2, device=device)[
            1:-1
        ]
        chunk_centers = chunk_positions * seq_lengths

        # Token positions along the padded sequence, shape [1, seq_len].
        token_positions = torch.arange(seq_len, device=device).float().unsqueeze(0)

        # Expand per-example lengths to [batch_size, coverage_chunks, seq_len]. Using the
        # padded length seq_len here (rather than the largest true length) keeps the
        # shape aligned with token_positions even when the batch is padded to a fixed size.
        seq_lengths = seq_lengths.view(seq_lengths.shape[0], 1, 1).repeat(
            1, self.coverage_chunks, seq_len
        )
        # Gaussian weight of every token for every chunk, with width sigma relative to
        # each example's length: shape [batch_size, coverage_chunks, seq_len].
        gaussians = torch.exp(
            -0.5
            * (
                (token_positions.unsqueeze(1) - chunk_centers.unsqueeze(2))
                / (self.sigma * seq_lengths)
            )
            ** 2
        )

        # Zero out weights on padding tokens, then renormalize so the weights for each
        # chunk sum to 1 over the sequence.
        gaussians = gaussians * attention_mask.squeeze(-1).unsqueeze(1)
        gaussians /= gaussians.sum(dim=2, keepdim=True).clamp(min=1e-9)

        # Weighted mean per chunk, shape [batch_size, coverage_chunks, hidden_dim].
        weighted_means = torch.einsum(
            "bns,bsh->bnh", gaussians.to(token_embeddings.dtype), token_embeddings
        )

        # Blend each chunk embedding with the standard mean according to alpha
        # (alpha = 1.0 keeps only the Gaussian-weighted means).
        combined_embeddings = (
            (1 - self.alpha) * standard_mean.unsqueeze(1) + self.alpha * weighted_means
        )

        # Prepend the standard mean as chunk 0, giving
        # [batch_size, coverage_chunks + 1, hidden_dim].
        combined_embeddings = torch.cat(
            [torch.zeros_like(combined_embeddings[:, :1]), combined_embeddings], 1
        )
        combined_embeddings[:, 0:1, :] = standard_mean.unsqueeze(1)

        # Optionally select a single chunk embedding per example (training mode).
        if chunk_indicators is not None:
            combined_embeddings = combined_embeddings[
                torch.arange(combined_embeddings.size(0), device=device), chunk_indicators
            ]

        # L2-normalize each embedding.
        combined_embeddings = torch.nn.functional.normalize(
            combined_embeddings, p=2, dim=-1
        )

        # Without chunk_indicators, concatenate the (coverage_chunks + 1) embeddings
        # into one vector per example.
        if chunk_indicators is None:
            sentence_embedding = combined_embeddings.reshape(
                batch_size, hidden_dim * (self.coverage_chunks + 1)
            )
        else:
            sentence_embedding = combined_embeddings

        features["sentence_embedding"] = sentence_embedding
        return features
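

# Output size note (the hidden size 768 below is illustrative, not taken from this
# file): with the default coverage_chunks=10 and a backbone hidden size of 768, the
# concatenated sentence embedding has 768 * (10 + 1) = 8448 dimensions; when
# chunk_indicators is given, the output stays at 768 dimensions per example.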


def use_gaussian_coverage_pooling(m, coverage_chunks=10, sigma=0.05, alpha=1.0):
    """
    Adds a custom pooling layer that computes weighted mean pooling using Gaussian-based weights.

    Args:
        m (SentenceTransformer): The model to add the pooling layer to.
        coverage_chunks (int): Number of weighted pooling operations (N).
        sigma (float): Standard deviation for Gaussian weighting.
        alpha (float): Weighting factor for merging with standard mean pooling.
    """
    # If the model is already wrapped, restore the original pooling first so that the
    # previously saved pooling module is not lost.
    if isinstance(m[1], GaussianCoveragePooling):
        m = unuse_gaussian_coverage_pooling(m)
    word_embedding_model = m[0]
    custom_pooling = GaussianCoveragePooling(
        coverage_chunks=coverage_chunks, sigma=sigma, alpha=alpha
    )
    # Keep a reference to the replaced pooling module so it can be restored later.
    old_pooling = m[1]
    new_m = m.__class__(modules=[word_embedding_model, custom_pooling])
    new_m.old_pooling = {"old_pooling": old_pooling}
    return new_m
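
# Usage sketch (illustrative; assumes the sentence_transformers package and a model
# name, neither of which is defined in this module):
#
#     from sentence_transformers import SentenceTransformer
#
#     model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
#     model = use_gaussian_coverage_pooling(model, coverage_chunks=10, sigma=0.05, alpha=1.0)
#     embeddings = model.encode(["a long document ..."])  # dim = hidden_dim * (10 + 1)
#     model = unuse_gaussian_coverage_pooling(model)  # restore the original pooling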


def unuse_gaussian_coverage_pooling(m):
    """
    Removes the custom pooling layer.

    Args:
        m (SentenceTransformer): The model to remove the pooling layer from.
    """
    if isinstance(m[1], GaussianCoveragePooling):
        new_m = m.__class__(modules=[m[0], m.old_pooling["old_pooling"]])
        return new_m
    else:
        return m
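

if __name__ == "__main__":
    # Minimal smoke test on random tensors (added as an illustration; the shapes and
    # values below are arbitrary, and in real use the features dict is produced by a
    # transformer module rather than built by hand).
    torch.manual_seed(0)
    pooling = GaussianCoveragePooling(coverage_chunks=3, sigma=0.05, alpha=1.0)
    features = {
        "token_embeddings": torch.randn(2, 16, 8),
        "attention_mask": torch.ones(2, 16, dtype=torch.long),
    }
    out = pooling(features)
    # With chunk_indicators=None the chunks are concatenated:
    # hidden_dim * (coverage_chunks + 1) = 8 * (3 + 1) = 32.
    print(out["sentence_embedding"].shape)  # torch.Size([2, 32])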