Update modeling_quasrav4.py
Browse files- modeling_quasrav4.py +2 -1
modeling_quasrav4.py
CHANGED
@@ -6,6 +6,7 @@ from typing import Optional, Tuple, List, Union
|
|
6 |
|
7 |
from transformers import PreTrainedModel
|
8 |
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
|
|
9 |
from transformers.utils import logging
|
10 |
|
11 |
from .configuration_quasrav4 import InfinityFormerConfig
|
@@ -213,7 +214,7 @@ class InfinityFormerModel(InfinityFormerPreTrainedModel):
|
|
213 |
return (hidden_states,)
|
214 |
return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=None, hidden_states=None, attentions=None)
|
215 |
|
216 |
-
class InfinityFormerForCausalLM(InfinityFormerPreTrainedModel):
|
217 |
_auto_class = "AutoModelForCausalLM"
|
218 |
|
219 |
def __init__(self, config: InfinityFormerConfig):
|
|
|
6 |
|
7 |
from transformers import PreTrainedModel
|
8 |
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
9 |
+
from transformers.generation import GenerationMixin
|
10 |
from transformers.utils import logging
|
11 |
|
12 |
from .configuration_quasrav4 import InfinityFormerConfig
|
|
|
214 |
return (hidden_states,)
|
215 |
return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=None, hidden_states=None, attentions=None)
|
216 |
|
217 |
+
class InfinityFormerForCausalLM(GenerationMixin, InfinityFormerPreTrainedModel):
|
218 |
_auto_class = "AutoModelForCausalLM"
|
219 |
|
220 |
def __init__(self, config: InfinityFormerConfig):
|