eyad-silx commited on
Commit
a18b96b
·
verified ·
1 Parent(s): 1059c6a

Update modeling_quasrav4.py

Browse files
Files changed (1) hide show
  1. modeling_quasrav4.py +2 -1
modeling_quasrav4.py CHANGED
@@ -6,6 +6,7 @@ from typing import Optional, Tuple, List, Union
6
 
7
  from transformers import PreTrainedModel
8
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 
9
  from transformers.utils import logging
10
 
11
  from .configuration_quasrav4 import InfinityFormerConfig
@@ -213,7 +214,7 @@ class InfinityFormerModel(InfinityFormerPreTrainedModel):
213
  return (hidden_states,)
214
  return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=None, hidden_states=None, attentions=None)
215
 
216
- class InfinityFormerForCausalLM(InfinityFormerPreTrainedModel):
217
  _auto_class = "AutoModelForCausalLM"
218
 
219
  def __init__(self, config: InfinityFormerConfig):
 
6
 
7
  from transformers import PreTrainedModel
8
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
9
+ from transformers.generation import GenerationMixin
10
  from transformers.utils import logging
11
 
12
  from .configuration_quasrav4 import InfinityFormerConfig
 
214
  return (hidden_states,)
215
  return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=None, hidden_states=None, attentions=None)
216
 
217
+ class InfinityFormerForCausalLM(GenerationMixin, InfinityFormerPreTrainedModel):
218
  _auto_class = "AutoModelForCausalLM"
219
 
220
  def __init__(self, config: InfinityFormerConfig):