import torch


class GaussianCoveragePooling(torch.nn.Module):
    def __init__(self, coverage_chunks, sigma, alpha):
        """
        Custom pooling layer that computes weighted mean pooling using Gaussian-based weights.
        Args:
            coverage_chunks (int): Number of weighted pooling operations (N).
            sigma (float): Standard deviation for Gaussian weighting.
            alpha (float): Weighting factor for merging with standard mean pooling.
        """
        super().__init__()
        self.coverage_chunks = coverage_chunks
        self.sigma = sigma  # Controls width of Gaussians
        self.alpha = alpha  # Blends standard mean with weighted mean
    def forward(self, features, chunk_indicators=None):
        """
        Computes weighted mean pooling using Gaussian-based weights.
        Args:
            features (dict): The token embeddings and attention mask.
            chunk_indicators (tensor[bz, 1]): Indices selecting a specific chunk per example;
                leave as None to return embeddings for all chunks. Mainly useful for training,
                not inference.
        """
        # Get token embeddings and attention mask
        token_embeddings = features[
            "token_embeddings"
        ]  # (batch_size, seq_len, hidden_dim)
        attention_mask = (
            features["attention_mask"].float().unsqueeze(-1)
        )  # (batch_size, seq_len, 1)

        # Get shapes and device
        batch_size, seq_len, hidden_dim = token_embeddings.shape
        device = token_embeddings.device

        # Compute actual sequence lengths (ignoring padding)
        # (batch_size, 1)
        seq_lengths = attention_mask.squeeze(-1).sum(dim=1, keepdim=True)
        max_seq_length = int(torch.max(seq_lengths).item())

        # Standard mean pooling over non-padding tokens
        sum_embeddings = torch.sum(token_embeddings * attention_mask, dim=1)
        sum_mask = torch.sum(attention_mask, dim=1).clamp(min=1e-9)
        standard_mean = sum_embeddings / sum_mask  # (batch_size, hidden_dim)
        # Compute chunk centers dynamically based on sequence length
        chunk_positions = torch.linspace(0, 1, self.coverage_chunks + 2, device=device)[
            1:-1
        ]  # Excludes 0 and 1
        chunk_centers = chunk_positions * seq_lengths  # (batch_size, N)

        # Token positions per sequence (batch_size, seq_len)
        token_positions = (
            torch.arange(seq_len, device=device).float().unsqueeze(0)
        )  # (1, seq_len)

        # Compute Gaussian weights (batch_size, N, seq_len)
        seq_lengths = seq_lengths.view(seq_lengths.shape[0], 1, 1).repeat(
            1, self.coverage_chunks, max_seq_length
        )
        gaussians = torch.exp(
            -0.5
            * (
                (token_positions.unsqueeze(1) - chunk_centers.unsqueeze(2))
                / (self.sigma * seq_lengths)
            )
            ** 2
        )
        # Mask out padding and normalize Gaussian weights per sequence
        # (batch_size, N, seq_len)
        gaussians = gaussians * attention_mask.squeeze(-1).unsqueeze(1)
        # Normalize the Gaussian weights so they sum to 1 over each sequence
        gaussians /= gaussians.sum(dim=2, keepdim=True).clamp(min=1e-9)

        # Compute weighted mean for each chunk (batch_size, N, hidden_dim)
        weighted_means = torch.einsum(
            "bns,bsh->bnh", gaussians.to(token_embeddings.dtype), token_embeddings
        )

        # Blend with standard mean pooling
        # (batch_size, N, hidden_dim)
        combined_embeddings = (1 - self.alpha) * standard_mean.unsqueeze(
            1
        ) + self.alpha * weighted_means

        # Add an embedding for the entire document at index 0
        # (batch_size, N+1, hidden_dim)
        combined_embeddings = torch.cat(
            [torch.zeros_like(combined_embeddings[:, :1]), combined_embeddings], 1
        )
        combined_embeddings[:, 0:1, :] = standard_mean.unsqueeze(1)

        # Select the indicated chunk for each example if provided
        if chunk_indicators is not None:
            combined_embeddings = combined_embeddings[
                torch.arange(combined_embeddings.size(0)), chunk_indicators
            ]

        # L2-normalize all the embeddings
        combined_embeddings = torch.nn.functional.normalize(
            combined_embeddings, p=2, dim=-1
        )

        # Flatten final embeddings (batch_size, hidden_dim * (N+1))
        if chunk_indicators is None:
            sentence_embedding = combined_embeddings.reshape(
                batch_size, hidden_dim * (self.coverage_chunks + 1)
            )
        else:
            sentence_embedding = combined_embeddings

        # Return the final sentence embedding
        features["sentence_embedding"] = sentence_embedding
        return features
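

# Shape sketch (illustrative only; the 384-dim backbone and dummy tensors below are
# assumptions, not part of this module). With coverage_chunks=10 the layer maps a
# batch of token embeddings to one flattened vector per document:
#     pooling = GaussianCoveragePooling(coverage_chunks=10, sigma=0.05, alpha=1.0)
#     features = {
#         "token_embeddings": torch.randn(2, 128, 384),
#         "attention_mask": torch.ones(2, 128),
#     }
#     out = pooling(features)["sentence_embedding"]  # (2, 384 * 11)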


def use_gaussian_coverage_pooling(m, coverage_chunks=10, sigma=0.05, alpha=1.0):
    """
    Adds a custom pooling layer that computes weighted mean pooling using Gaussian-based weights.
    Args:
        m (SentenceTransformer): The model to add the pooling layer to.
        coverage_chunks (int): Number of weighted pooling operations (N).
        sigma (float): Standard deviation for Gaussian weighting.
        alpha (float): Weighting factor for merging with standard mean pooling.
    """
    if isinstance(m[1], GaussianCoveragePooling):
        m = unuse_gaussian_coverage_pooling(m)
    word_embedding_model = m[0]
    custom_pooling = GaussianCoveragePooling(
        coverage_chunks=coverage_chunks, sigma=sigma, alpha=alpha
    )
    old_pooling = m[1]
    new_m = m.__class__(modules=[word_embedding_model, custom_pooling])
    new_m.old_pooling = {"old_pooling": old_pooling}
    return new_m


def unuse_gaussian_coverage_pooling(m):
    """
    Removes the custom pooling layer.
    Args:
        m (SentenceTransformer): The model to remove the pooling layer from.
    """
    if isinstance(m[1], GaussianCoveragePooling):
        new_m = m.__class__(modules=[m[0], m.old_pooling["old_pooling"]])
        return new_m
    else:
        return m
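

if __name__ == "__main__":
    # Usage sketch (an assumption for illustration: requires the sentence-transformers
    # package, and the model name below is a placeholder, not prescribed by this module).
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    model = use_gaussian_coverage_pooling(model, coverage_chunks=10, sigma=0.05, alpha=1.0)

    # Each embedding is the whole-document vector at index 0 followed by N chunk-focused
    # vectors, flattened to hidden_dim * (coverage_chunks + 1) values.
    embeddings = model.encode(["A long document whose sections should all be covered."])
    print(embeddings.shape)  # e.g. (1, 384 * 11) for a 384-dim backbone

    # Restore the original pooling layer.
    model = unuse_gaussian_coverage_pooling(model)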