ydshieh
commited on
Commit
·
9f6265f
1
Parent(s):
18936e9
fix tokenization hang on TPU VM when worker > 1
Browse files
run_image_captioning_flax.py
CHANGED
@@ -708,7 +708,7 @@ def main():
|
|
708 |
|
709 |
model_inputs["labels"] = labels["input_ids"]
|
710 |
decoder_input_ids = shift_tokens_right_fn(
|
711 |
-
|
712 |
)
|
713 |
model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
|
714 |
|
|
|
708 |
|
709 |
model_inputs["labels"] = labels["input_ids"]
|
710 |
decoder_input_ids = shift_tokens_right_fn(
|
711 |
+
labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
|
712 |
)
|
713 |
model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
|
714 |
|