ydshieh commited on
Commit
9f6265f
·
1 Parent(s): 18936e9

fix tokenization hang on TPU VM when worker > 1

Browse files
Files changed (1) hide show
  1. run_image_captioning_flax.py +1 -1
run_image_captioning_flax.py CHANGED
@@ -708,7 +708,7 @@ def main():
708
 
709
  model_inputs["labels"] = labels["input_ids"]
710
  decoder_input_ids = shift_tokens_right_fn(
711
- jnp.array(labels["input_ids"]), config.pad_token_id, config.decoder_start_token_id
712
  )
713
  model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
714
 
 
708
 
709
  model_inputs["labels"] = labels["input_ids"]
710
  decoder_input_ids = shift_tokens_right_fn(
711
+ labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
712
  )
713
  model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
714