import torch


class GaussianCoveragePooling(torch.nn.Module):
    """Pooling layer producing a document embedding plus N Gaussian-weighted chunk embeddings.

    Output slot 0 is the masked mean of all token embeddings. Slots 1..N are
    means weighted by Gaussians centred at evenly spaced relative positions
    along each sequence, blended with the plain mean via ``alpha``. Every slot
    is L2-normalised.
    """

    def __init__(self, coverage_chunks, sigma, alpha):
        """
        Custom pooling layer that computes weighted mean pooling using Gaussian-based weights.

        Args:
            coverage_chunks (int): Number of weighted pooling operations (N).
            sigma (float): Standard deviation for Gaussian weighting, expressed
                as a fraction of each sequence's real (unpadded) length.
            alpha (float): Weighting factor for merging with standard mean
                pooling (1.0 = pure Gaussian-weighted mean, 0.0 = pure mean).
        """
        super().__init__()
        self.coverage_chunks = coverage_chunks
        self.sigma = sigma  # Controls width of Gaussians
        self.alpha = alpha  # Blends standard mean with weighted mean

    def forward(self, features, chunk_indicators=None):
        """
        Computes weighted mean pooling using Gaussian-based weights.

        Args:
            features (dict): Must contain "token_embeddings"
                (batch, seq_len, hidden) and "attention_mask" (batch, seq_len).
                The dict is updated in place and returned.
            chunk_indicators (tensor[bz] or tensor[bz, 1], optional): Per-row
                index (0 = whole document, 1..N = chunk) selecting a single
                embedding. Mainly useful for training; leave as None for
                inference to get all chunks.

        Returns:
            dict: ``features`` with "sentence_embedding" set to a tensor of
            shape (batch, hidden * (N + 1)) when ``chunk_indicators`` is None,
            otherwise (batch, hidden).
        """
        token_embeddings = features["token_embeddings"]  # (batch, seq_len, hidden)
        mask = features["attention_mask"].float()  # (batch, seq_len)

        batch_size, seq_len, hidden_dim = token_embeddings.shape
        device = token_embeddings.device

        # Actual sequence lengths, ignoring padding: (batch, 1)
        seq_lengths = mask.sum(dim=1, keepdim=True)

        # Standard masked mean pooling: (batch, hidden)
        sum_embeddings = torch.sum(token_embeddings * mask.unsqueeze(-1), dim=1)
        standard_mean = sum_embeddings / seq_lengths.clamp(min=1e-9)

        # Chunk centres at evenly spaced fractions of each sequence, excluding 0 and 1.
        # NOTE(review): assumes right padding (real tokens at positions
        # 0..len-1) — confirm for tokenizers that left-pad.
        chunk_positions = torch.linspace(
            0, 1, self.coverage_chunks + 2, device=device
        )[1:-1]
        chunk_centers = chunk_positions * seq_lengths  # (batch, N)

        # Gaussian weights (batch, N, seq_len). The width is kept as
        # (batch, 1, 1) so it broadcasts against any padded seq_len; the clamp
        # keeps all-padding rows finite instead of producing 0/0 = NaN.
        token_positions = torch.arange(seq_len, device=device, dtype=torch.float32)
        widths = (self.sigma * seq_lengths.clamp(min=1e-9)).unsqueeze(-1)
        offsets = token_positions.view(1, 1, -1) - chunk_centers.unsqueeze(-1)
        gaussians = torch.exp(-0.5 * (offsets / widths) ** 2)

        # Mask out padding, then normalise each Gaussian to sum to 1 per chunk.
        gaussians = gaussians * mask.unsqueeze(1)
        gaussians = gaussians / gaussians.sum(dim=2, keepdim=True).clamp(min=1e-9)

        # Weighted mean per chunk: (batch, N, hidden)
        weighted_means = torch.einsum(
            "bns,bsh->bnh", gaussians.to(token_embeddings.dtype), token_embeddings
        )

        # Blend with the standard mean: (batch, N, hidden)
        chunk_embeddings = (1 - self.alpha) * standard_mean.unsqueeze(
            1
        ) + self.alpha * weighted_means

        # Whole-document embedding at index 0: (batch, N + 1, hidden)
        combined = torch.cat([standard_mean.unsqueeze(1), chunk_embeddings], dim=1)

        if chunk_indicators is not None:
            # Flatten so both (bz,) and the documented (bz, 1) shape select
            # exactly one embedding per row instead of broadcasting to (bz, bz).
            indices = (
                torch.as_tensor(chunk_indicators, device=combined.device)
                .reshape(-1)
                .long()
            )
            rows = torch.arange(batch_size, device=combined.device)
            combined = combined[rows, indices]

        # L2-normalise every embedding independently.
        combined = torch.nn.functional.normalize(combined, p=2, dim=-1)

        if chunk_indicators is None:
            # Flatten to (batch, hidden * (N + 1))
            sentence_embedding = combined.reshape(
                batch_size, hidden_dim * (self.coverage_chunks + 1)
            )
        else:
            sentence_embedding = combined

        # Return the final (flattened) sentence embedding
        features["sentence_embedding"] = sentence_embedding
        return features


def use_gaussian_coverage_pooling(m, coverage_chunks=10, sigma=0.05, alpha=1.0):
    """
    Add custom pooling layer that computes weighted mean pooling using Gaussian-based weights.

    The returned model contains only ``m[0]`` followed by the new pooling
    layer. The original pooling layer and any modules after it are stashed on
    the new model so ``unuse_gaussian_coverage_pooling`` can restore them.
    If ``m`` is already patched, it is un-patched first.

    Args:
        m (SentenceTransformer): The model to add pooling layer to.
        coverage_chunks (int): Number of weighted pooling operations (N).
        sigma (float): Standard deviation for Gaussian weighting.
        alpha (float): Weighting factor for merging with standard mean pooling.

    Returns:
        A new model of the same class as ``m``.
    """
    if isinstance(m[1], GaussianCoveragePooling):
        m = unuse_gaussian_coverage_pooling(m)

    word_embedding_model = m[0]
    old_pooling = m[1]
    # Modules after the pooling layer (e.g. normalisation) are not used by the
    # patched model, but are kept so they can be restored later.
    trailing_modules = list(m)[2:]

    custom_pooling = GaussianCoveragePooling(
        coverage_chunks=coverage_chunks, sigma=sigma, alpha=alpha
    )
    new_m = m.__class__(modules=[word_embedding_model, custom_pooling])
    # Wrapped in a plain dict so nn.Module.__setattr__ does not register these
    # modules as children of new_m (which would also make them part of its
    # module list).
    new_m.old_pooling = {
        "old_pooling": old_pooling,
        "trailing_modules": trailing_modules,
    }
    return new_m


def unuse_gaussian_coverage_pooling(m):
    """
    Removes the custom pooling layer.

    Rebuilds the model from ``m[0]``, the stashed original pooling layer and any
    stashed trailing modules. Models without the custom pooling are returned
    unchanged.

    Args:
        m (SentenceTransformer): The model to remove the pooling layer from.

    Returns:
        The restored model, or ``m`` itself if it was not patched.
    """
    if not isinstance(m[1], GaussianCoveragePooling):
        return m
    saved = m.old_pooling
    # .get keeps models patched before trailing modules were stashed working.
    trailing_modules = saved.get("trailing_modules", [])
    return m.__class__(modules=[m[0], saved["old_pooling"], *trailing_modules])