Upload SegmentNT
- config.json +1 -1
- modeling_segment_nt.py +29 -27
- pytorch_model.bin +2 -2
    	
config.json
CHANGED

@@ -40,7 +40,7 @@
   "num_layers_head": 2,
   "pad_token_id": 1,
   "position_embedding_type": "rotary",
-  "rescaling_factor": 
+  "rescaling_factor": 2.44140625,
   "tie_word_embeddings": false,
   "token_dropout": false,
   "torch_dtype": "float32",
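Note on the new value: with rotary position embeddings, a rescaling_factor r rescales the base frequency from 10000 to 10000 * r ** (dim / (dim - 2)), as implemented in the modeling change below (an NTK-style base rescaling). The committed value 2.44140625 is numerically equal to 5000/2048, consistent with stretching the rotary period for longer inputs. A minimal sketch of the arithmetic, assuming an illustrative head dimension of 64 (the real dim comes from the model config, not this snippet):

    import torch

    rescaling_factor = 2.44140625  # value added in this commit (= 5000 / 2048)
    dim = 64                       # illustrative head dimension, not from the config

    # Base rescaling as in modeling_segment_nt.py below
    updated_base = 10000 * rescaling_factor ** (dim / (dim - 2))
    inv_freq = 1.0 / (updated_base ** (torch.arange(0, dim, 2).float() / dim))
    print(updated_base)  # ~25127 for dim=64: a ~2.5x larger base, slower-rotating frequencies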
    	
modeling_segment_nt.py
CHANGED

@@ -115,56 +115,58 @@ class RotaryEmbedding(torch.nn.Module):
         super().__init__()
 
         # Extract argument from the config
-        rescaling_factor = rotary_embedding_config.rescaling_factor
-        upper_freq = 10000
-
-        if rescaling_factor is None:
-            inv_freq = 1.0 / (upper_freq ** (torch.arange(0, dim, 2).float() / dim))
-        else:
-            updated_base = upper_freq * (
-                rescaling_factor ** (dim / (dim - 2))
-            )
-            inv_freq = 1.0 / (
-                updated_base ** (torch.arange(0, dim, 2).float() / dim)
-            )
-
-        self.register_buffer("inv_freq", inv_freq)
+        self.rescaling_factor = rotary_embedding_config.rescaling_factor
+        self.upper_freq = 10000
+        self.dim = dim
 
         self._seq_len_cached = None
         self._cos_cached = None
         self._sin_cached = None
 
-    def _compute_cos_sin_tables(self, x, seq_dimension=2):
+
+
+    def _compute_cos_sin_tables(self, x, inv_freq, seq_dimension=2):
         seq_len = x.shape[seq_dimension]
 
         # Reset the tables if the sequence length has changed,
         # or if we're on a new device (possibly due to tracing for instance)
-        if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
-            self._seq_len_cached = seq_len
-            t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(
-                self.inv_freq
-            )
-            freqs = torch.outer(t, self.inv_freq)
-            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+        self._seq_len_cached = seq_len
+        t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(
+            inv_freq
+        )
+        freqs = torch.outer(t, inv_freq)
+        emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
 
-            self._cos_cached = emb.cos()[None, None, :, :]
-            self._sin_cached = emb.sin()[None, None, :, :]
+        self._cos_cached = emb.cos()[None, None, :, :]
+        self._sin_cached = emb.sin()[None, None, :, :]
 
         return self._cos_cached, self._sin_cached
 
     def forward(
         self, q: torch.Tensor, k: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        self._cos_cached, self._sin_cached = self._compute_cos_sin_tables(
-            k, seq_dimension=-2
-        )
+
+        if self.rescaling_factor is None:
+            inv_freq = 1.0 / (self.upper_freq ** (torch.arange(0, self.dim, 2).float() / self.dim))
+        else:
+            updated_base = self.upper_freq * (
+                self.rescaling_factor ** (self.dim / (self.dim - 2))
+            )
+            inv_freq = 1.0 / (
+                updated_base ** (torch.arange(0, self.dim, 2).float() / self.dim)
+            )
 
+        self._cos_cached, self._sin_cached = self._compute_cos_sin_tables(
+            k, inv_freq, seq_dimension=-2,
+        )
+
         return (
             apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
             apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
         )
 
 
+
 class EsmContactPredictionHead(nn.Module):
     """Performs symmetrization, apc, and computes a logistic regression on the output features"""
 
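The net effect of this refactor is that inv_freq is no longer a buffer fixed at construction time; it is recomputed on every forward pass from self.rescaling_factor, and the cos/sin tables are rebuilt unconditionally rather than cached behind a sequence-length check. A standalone sketch of the new behavior (illustrative names and shapes; only the formulas are taken from the diff above):

    import torch

    def compute_cos_sin(x: torch.Tensor, inv_freq: torch.Tensor, seq_dimension: int = -2):
        # Mirrors the refactored _compute_cos_sin_tables: one row of angles per position
        t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(inv_freq)
        freqs = torch.outer(t, inv_freq)         # (seq_len, dim // 2)
        emb = torch.cat((freqs, freqs), dim=-1)  # (seq_len, dim)
        return emb.cos()[None, None, :, :], emb.sin()[None, None, :, :]

    dim, seq_len, rescaling_factor = 64, 16, 2.44140625
    updated_base = 10000 * rescaling_factor ** (dim / (dim - 2))
    inv_freq = 1.0 / (updated_base ** (torch.arange(0, dim, 2).float() / dim))

    k = torch.randn(1, 1, seq_len, dim)          # (batch, heads, seq, dim)
    cos, sin = compute_cos_sin(k, inv_freq)
    assert cos.shape == (1, 1, seq_len, dim)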
    	
pytorch_model.bin
CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size 
+oid sha256:bf3b06784e943efd3f33b6059ad921218490cd691d2a0ffb11db3da8ef424b5d
+size 2237465429
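The replaced lines are a Git LFS pointer rather than the weights themselves: per the git-lfs v1 spec, oid sha256 is the SHA-256 of the binary's contents and size is its length in bytes (~2.08 GiB here). A sketch for verifying a downloaded copy against this pointer, assuming a hypothetical local path:

    import hashlib
    from pathlib import Path

    EXPECTED_OID = "bf3b06784e943efd3f33b6059ad921218490cd691d2a0ffb11db3da8ef424b5d"
    EXPECTED_SIZE = 2237465429

    path = Path("pytorch_model.bin")  # hypothetical download location
    assert path.stat().st_size == EXPECTED_SIZE, "size mismatch"

    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
            digest.update(chunk)
    assert digest.hexdigest() == EXPECTED_OID, "checksum mismatch"
    print("pytorch_model.bin matches the LFS pointer")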

