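"""Gaussian coverage pooling for SentenceTransformer models.

Pools token embeddings into N+1 vectors: a standard mean over the whole
sequence plus N Gaussian-weighted means centered at evenly spaced points, so
different regions of a long document each contribute their own embedding.
"""
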
import torch


class GaussianCoveragePooling(torch.nn.Module):
    def __init__(self, coverage_chunks, sigma, alpha):
        """
        Custom pooling layer that computes weighted mean pooling using Gaussian-based weights.

        Args:
            coverage_chunks (int): Number of weighted pooling operations (N).
            sigma (float): Standard deviation of each Gaussian, as a fraction
                of the sequence length.
            alpha (float): Blend factor between standard mean pooling (0.0)
                and the Gaussian-weighted means (1.0).
        """
        super().__init__()
        self.coverage_chunks = coverage_chunks
        self.sigma = sigma  # Controls width of Gaussians
        self.alpha = alpha  # Blends standard mean with weighted mean

    def forward(self, features, chunk_indicators=None):
        """
        Computes weighted mean pooling using Gaussian-based weights.

        Args:
            features (dict): The token embeddings and attention mask.
            chunk_indicators (torch.Tensor[batch_size], optional): Indices that
                select a single chunk embedding per example. Mainly useful for
                training; leave as None to return embeddings for all chunks.
        """

        # Get token embeddings and attention mask
        token_embeddings = features[
            "token_embeddings"
        ]  # (batch_size, seq_len, hidden_dim)
        attention_mask = (
            features["attention_mask"].float().unsqueeze(-1)
        )  # (batch_size, seq_len, 1)

        # Get shapes and devices
        batch_size, seq_len, hidden_dim = token_embeddings.shape
        device = token_embeddings.device

        # Compute actual sequence lengths (ignoring padding)
        # (batch_size, 1)
        seq_lengths = attention_mask.squeeze(-1).sum(dim=1, keepdim=True)

        # Standard mean pooling
        sum_embeddings = torch.sum(token_embeddings * attention_mask, dim=1)
        sum_mask = torch.sum(attention_mask, dim=1).clamp(min=1e-9)
        standard_mean = sum_embeddings / sum_mask  # (batch_size, hidden_dim)

        # Compute chunk centers dynamically based on sequence length
        chunk_positions = torch.linspace(0, 1, self.coverage_chunks + 2, device=device)[
            1:-1
        ]  # Excludes 0 and 1
        chunk_centers = chunk_positions * seq_lengths  # (batch_size, N)
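        # e.g. coverage_chunks=3 places the centers at 25%, 50%, and 75% of
        # each sequence's true length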

        # Token positions per sequence (batch_size, seq_len)
        token_positions = (
            torch.arange(seq_len, device=device).float().unsqueeze(0)
        )  # (1, seq_len)

        # Compute Gaussian weights (batch_size, N, seq_len):
        # w[b, n, t] = exp(-0.5 * ((t - center[b, n]) / (sigma * length[b])) ** 2)
        # so each Gaussian's width scales with the sequence's true length
        seq_lengths = seq_lengths.unsqueeze(-1)  # (batch_size, 1, 1), broadcasts below
        gaussians = torch.exp(
            -0.5
            * (
                (token_positions.unsqueeze(1) - chunk_centers.unsqueeze(2))
                / (self.sigma * seq_lengths)
            )
            ** 2
        )

        # Mask out padding
        # (batch_size, N, seq_len)
        gaussians = gaussians * attention_mask.squeeze(-1).unsqueeze(1)

        # Normalize so each chunk's weights sum to 1 over its sequence
        gaussians = gaussians / gaussians.sum(dim=2, keepdim=True).clamp(min=1e-9)

        # Compute weighted mean for each chunk (batch_size, N, hidden_dim)
        weighted_means = torch.einsum(
            "bns,bsh->bnh", gaussians.to(token_embeddings.dtype), token_embeddings
        )

        # Blend with standard mean pooling
        # (batch_size, N, hidden_dim)
        combined_embeddings = (1 - self.alpha) * standard_mean.unsqueeze(
            1
        ) + self.alpha * weighted_means
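        # alpha=1.0 keeps purely the Gaussian-weighted means; alpha=0.0
        # collapses every chunk to the standard mean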

        # Prepend the standard mean as the whole-document embedding at index 0
        # (batch_size, N+1, hidden_dim)
        combined_embeddings = torch.cat(
            [standard_mean.unsqueeze(1), combined_embeddings], dim=1
        )

        # Select the indicator if provided
        if chunk_indicators is not None:
            combined_embeddings = combined_embeddings[
                torch.arange(combined_embeddings.size(0)), chunk_indicators
            ]

        # Normalize all the embeddings
        combined_embeddings = torch.nn.functional.normalize(
            combined_embeddings, p=2, dim=-1
        )
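        # Each of the N+1 vectors is now unit length, so a dot product between
        # two flattened embeddings is a sum of per-chunk cosine similarities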

        # Flatten final embeddings (batch_size, hidden_dim * (N+1))
        if chunk_indicators is None:
            sentence_embedding = combined_embeddings.reshape(
                batch_size, hidden_dim * (self.coverage_chunks + 1)
            )
        else:
            sentence_embedding = combined_embeddings

        # Return the final flattened sentence embedding
        features["sentence_embedding"] = sentence_embedding
        return features


def use_gaussian_coverage_pooling(m, coverage_chunks=10, sigma=0.05, alpha=1.0):
    """
    Add custom pooling layer that computes weighted mean pooling using Gaussian-based weights.

    Args:
        m (SentenceTransformer): The model to add pooling layer to.
        coverage_chunks (int): Number of weighted pooling operations (N).
        sigma (float): Standard deviation for Gaussian weighting.
        alpha (float): Weighting factor for merging with standard mean pooling.
    """
    if isinstance(m[1], GaussianCoveragePooling):
        m = unuse_gaussian_coverage_pooling(m)
    word_embedding_model = m[0]
    custom_pooling = GaussianCoveragePooling(
        coverage_chunks=coverage_chunks, sigma=sigma, alpha=alpha
    )
    old_pooling = m[1]
    new_m = m.__class__(modules=[word_embedding_model, custom_pooling])
    new_m.old_pooling = {"old_pooling": old_pooling}
    return new_m


def unuse_gaussian_coverage_pooling(m):
    """
    Removes the custom pooling layer.

    Args:
        m (SentenceTransformer): The model to remove the pooling layer from.
    """

    if isinstance(m[1], GaussianCoveragePooling):
        new_m = m.__class__(modules=[m[0], m.old_pooling["old_pooling"]])
        return new_m
    else:
        return m
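

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: it assumes
    # sentence-transformers is installed, and the model name below is an
    # arbitrary example rather than a required choice.
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    model = use_gaussian_coverage_pooling(
        model, coverage_chunks=10, sigma=0.05, alpha=1.0
    )

    # Each embedding concatenates N+1 unit vectors: index 0 covers the whole
    # document, indices 1..N emphasize successive regions of the text.
    emb = model.encode(["A long document whose sections should stay retrievable."])
    print(emb.shape)  # (1, hidden_dim * (coverage_chunks + 1))

    # Restore the original pooling layer.
    model = unuse_gaussian_coverage_pooling(model)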