Update ts_generation_mixin.py
ts_generation_mixin.py  CHANGED  (+34 -17)
@@ -6,8 +6,38 @@ from transformers.generation import validate_stopping_criteria, EosTokenCriteria
 from transformers.generation.utils import GenerateNonBeamOutput, GenerateEncoderDecoderOutput, GenerateDecoderOnlyOutput, GenerationConfig, GenerateOutput
 from transformers.utils import ModelOutput
 
+
 class TSGenerationMixin(GenerationMixin):
 
+    @torch.no_grad()
+    def generate(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        generation_config: Optional[GenerationConfig] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+        synced_gpus: Optional[bool] = None,
+        assistant_model: Optional["PreTrainedModel"] = None,
+        streamer: Optional["BaseStreamer"] = None,
+        negative_prompt_ids: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[GenerateOutput, torch.LongTensor]:
+        if len(inputs.shape) == 2:
+            batch_size, cur_len = inputs.shape
+            if cur_len < self.config.input_token_len:
+                raise ValueError(
+                    f"Input length must be at least {self.config.input_token_len}")
+            elif cur_len % self.config.input_token_len != 0:
+                new_len = (cur_len // self.config.input_token_len) * \
+                    self.config.input_token_len
+                inputs = inputs[:, -new_len:]
+        else:
+            raise ValueError('Input shape must be: [batch_size, seq_len]')
+        return super().generate(inputs=inputs, generation_config=generation_config, logits_processor=logits_processor, stopping_criteria=stopping_criteria, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, synced_gpus=synced_gpus, assistant_model=assistant_model, streamer=streamer, negative_prompt_ids=negative_prompt_ids, negative_prompt_attention_mask=negative_prompt_attention_mask, **kwargs)
+
+
     def _greedy_search(
         self,
         input_ids: torch.Tensor,
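The new generate() override validates the context shape and floors its length to a multiple of config.input_token_len, keeping only the most recent values, before delegating to the stock GenerationMixin.generate(). A minimal arithmetic sketch of that truncation, using a hypothetical input_token_len of 96 (the real value comes from the model config, not from this commit):

# Sketch only; input_token_len = 96 and cur_len = 1000 are made-up numbers.
input_token_len = 96
cur_len = 1000                                       # raw context length from the caller
new_len = (cur_len // input_token_len) * input_token_len   # 10 * 96 = 960
# generate() then keeps the most recent new_len steps:
#   inputs = inputs[:, -new_len:]   ->  shape (batch_size, 960)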
@@ -26,19 +56,7 @@ class TSGenerationMixin(GenerationMixin):
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.Tensor]:
         input_ids = input_ids.to(self.device)
-
-        if len(input_ids.shape) == 2:
-            batch_size, cur_len = input_ids.shape
-            if cur_len < self.config.input_token_len:
-                raise ValueError(
-                    f"Input length must be at least {self.config.input_token_len}")
-            elif cur_len % self.config.input_token_len != 0:
-                new_len = (cur_len // self.config.input_token_len) * \
-                    self.config.input_token_len
-                input_ids = input_ids[:, -new_len:]
-        else:
-            raise ValueError('Input shape must be: [batch_size, seq_len]')
-
+        batch_size, cur_len = input_ids.shape
         # init values
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
@@ -106,9 +124,8 @@ class TSGenerationMixin(GenerationMixin):
             batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(
            cur_len, device=input_ids.device)
-        true_seq_len =
+        true_seq_len = cur_len // self.config.input_token_len
         model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, -true_seq_len:]
-
         max_length = stopping_criteria.max_length
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # prepare model inputs
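With this hunk, true_seq_len counts input tokens rather than raw time steps: every input_token_len consecutive values form one token, so the attention mask is trimmed to cur_len // input_token_len positions. A small worked example under the same hypothetical input_token_len of 96:

# Hypothetical numbers, consistent with the sketch above.
cur_len = 960
input_token_len = 96
true_seq_len = cur_len // input_token_len   # 10 attention positions
# model_kwargs["attention_mask"] is sliced to its last true_seq_len columns,
# one per input token the backbone actually attends over.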
@@ -129,7 +146,7 @@ class TSGenerationMixin(GenerationMixin):
             if synced_gpus and this_peer_finished:
                 continue  # don't waste resources running the code we don't need
 
-            next_token_logits = outputs.logits
+            next_token_logits = outputs.logits
 
             # pre-process distribution
             next_tokens_scores = logits_processor(input_ids, next_token_logits)
@@ -212,7 +229,7 @@ class TSGenerationMixin(GenerationMixin):
                 past_key_values=model_kwargs.get("past_key_values"),
             )
         else:
-            return input_ids[:, -(max_length -
+            return input_ids[:, -(max_length - cur_len):]
 
     def _update_model_kwargs_for_generation(
         self,
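When return_dict_in_generate is not requested, _greedy_search now returns only the newly generated slice, i.e. the last max_length - cur_len positions, rather than the context plus forecast. A hedged usage sketch; the checkpoint name, horizon, and AutoModelForCausalLM loading path are assumptions for illustration, not part of this commit:

# Hypothetical usage; "some/ts-model" and the numbers are placeholders.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("some/ts-model", trust_remote_code=True)
seqs = torch.randn(2, 960)                    # [batch_size, seq_len] float series
forecast = model.generate(seqs, max_new_tokens=96)
# forecast holds only the newly generated values, not the 960-step context.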