compatible with DirectML/ROCm (#5)
Browse files- compatible with DirectML/ROCm (5bc5aff72b4d8fbdb10f7befa1473a741ecec8b5)
Co-authored-by: Davin Wang <[email protected]>
- modeling_chatglm.py +2 -1
modeling_chatglm.py
CHANGED
|
@@ -16,6 +16,7 @@ from transformers.modeling_outputs import (
|
|
| 16 |
BaseModelOutputWithPast,
|
| 17 |
CausalLMOutputWithPast,
|
| 18 |
)
|
|
|
|
| 19 |
from transformers.modeling_utils import PreTrainedModel
|
| 20 |
from transformers.utils import logging
|
| 21 |
from transformers.generation.logits_process import LogitsProcessor
|
|
@@ -1138,7 +1139,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
| 1138 |
)
|
| 1139 |
logits_warper = self._get_logits_warper(generation_config)
|
| 1140 |
|
| 1141 |
-
unfinished_sequences =
|
| 1142 |
scores = None
|
| 1143 |
while True:
|
| 1144 |
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
|
|
|
| 16 |
BaseModelOutputWithPast,
|
| 17 |
CausalLMOutputWithPast,
|
| 18 |
)
|
| 19 |
+
|
| 20 |
from transformers.modeling_utils import PreTrainedModel
|
| 21 |
from transformers.utils import logging
|
| 22 |
from transformers.generation.logits_process import LogitsProcessor
|
|
|
|
| 1139 |
)
|
| 1140 |
logits_warper = self._get_logits_warper(generation_config)
|
| 1141 |
|
| 1142 |
+
unfinished_sequences = torch.ones(input_ids.shape[0], device=input_ids.device, dtype=input_ids.dtype)
|
| 1143 |
scores = None
|
| 1144 |
while True:
|
| 1145 |
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|