Update modularStarEncoder.py
Browse files- modularStarEncoder.py +5 -0
    	
        modularStarEncoder.py
    CHANGED
    
    | @@ -204,7 +204,12 @@ def get_pooling_mask( | |
| 204 |  | 
| 205 | 
             
                repeated_idx = idx.unsqueeze(1).repeat(1, input_ids.size(1))
         | 
| 206 |  | 
|  | |
|  | |
|  | |
|  | |
| 207 | 
             
                ranges = torch.arange(input_ids.size(1)).repeat(input_ids.size(0), 1)
         | 
|  | |
| 208 |  | 
| 209 | 
             
                pooling_mask = (repeated_idx <= ranges).long()
         | 
| 210 |  | 
|  | |
| 204 |  | 
| 205 | 
             
                repeated_idx = idx.unsqueeze(1).repeat(1, input_ids.size(1))
         | 
| 206 |  | 
| 207 | 
            +
                DEVICE = input_ids.get_device()
         | 
| 208 | 
            +
             | 
| 209 | 
            +
                if DEVICE<0:
         | 
| 210 | 
            +
                    DEVICE = "cpu"
         | 
| 211 | 
             
                ranges = torch.arange(input_ids.size(1)).repeat(input_ids.size(0), 1)
         | 
| 212 | 
            +
                ranges = ranges.to(DEVICE)
         | 
| 213 |  | 
| 214 | 
             
                pooling_mask = (repeated_idx <= ranges).long()
         | 
| 215 |  | 
