add explicit cast where running without autocast causes issues
attention.py  +1 -1

@@ -55,7 +55,7 @@ def scaled_multihead_dot_product_attention(query, key, value, n_heads, past_key_
     attn_weight = torch.softmax(attn_weight, dim=-1)
     if dropout_p:
         attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p, training=training, inplace=True)
-    out = attn_weight.matmul(v)
+    out = attn_weight.to(v.dtype).matmul(v)
     out = rearrange(out, 'b h s d -> b s (h d)')
     if needs_weights:
         return (out, attn_weight, past_key_value)
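For context, a minimal sketch of the failure this cast avoids (an illustration with assumed shapes, not code from this repo): when the softmax output is float32 but v is half precision, the mixed-dtype matmul raises a RuntimeError outside torch.autocast, whereas under autocast the cast happens implicitly.

import torch

# Illustrative repro (assumed shapes): the softmax output is float32
# while v is half precision, as happens when upstream ops upcast.
attn_weight = torch.softmax(torch.randn(1, 8, 4, 4), dim=-1)  # float32
v = torch.randn(1, 8, 4, 16, dtype=torch.bfloat16)            # half precision

# Without autocast, attn_weight.matmul(v) fails with a dtype-mismatch
# RuntimeError; the explicit cast works both with and without autocast.
out = attn_weight.to(v.dtype).matmul(v)
print(out.dtype)  # torch.bfloat16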

