fix: force correct mixed dtype after HF load
Browse files
- modeling_hyena.py +2 -1
modeling_hyena.py
CHANGED
|
@@ -45,8 +45,9 @@ class StripedHyenaModelForCausalLM(StripedHyenaPreTrainedModel):
|
|
| 45 |
)
|
| 46 |
self.vocab_size = vocab_size
|
| 47 |
self.post_init()
|
|
|
|
| 48 |
|
| 49 |
-
def
|
| 50 |
self.backbone.to_bfloat16_except_poles_residues()
|
| 51 |
|
| 52 |
def _set_gradient_checkpointing(self, enable, gradient_checkpointing_func):
|
|
|
|
| 45 |
)
|
| 46 |
self.vocab_size = vocab_size
|
| 47 |
self.post_init()
|
| 48 |
+
self.force_dtype()
|
| 49 |
|
| 50 |
+
def force_dtype(self):
|
| 51 |
self.backbone.to_bfloat16_except_poles_residues()
|
| 52 |
|
| 53 |
def _set_gradient_checkpointing(self, enable, gradient_checkpointing_func):
|