ydshieh committed on
Commit
15ecbe8
·
1 Parent(s): f2e4555
run_image_captioning_flax_reduced.py CHANGED
@@ -503,6 +503,9 @@ def main():
503
  raise ValueError(
504
  "Decoder Config: Either a pretrained config or a model location for decoder is required."
505
  )
 
 
 
506
 
507
  # GPT2 only has bos/eos token but not decoder_start/pad token
508
  if decoder_config.decoder_start_token_id is None:
@@ -520,7 +523,7 @@ def main():
520
  encoder_dtype=getattr(jnp, model_args.dtype),
521
  decoder_dtype=getattr(jnp, model_args.dtype),
522
  )
523
- # Necessary for Flax's generate()
524
  model.config.decoder_start_token_id = decoder_config.decoder_start_token_id
525
 
526
  if model_args.feature_extractor_name:
 
503
  raise ValueError(
504
  "Decoder Config: Either a pretrained config or a model location for decoder is required."
505
  )
506
+ # necessary for `from_encoder_decoder_pretrained` when `decoder_config` is passed
507
+ decoder_config.is_decoder = True
508
+ decoder_config.add_cross_attention = True
509
 
510
  # GPT2 only has bos/eos token but not decoder_start/pad token
511
  if decoder_config.decoder_start_token_id is None:
 
523
  encoder_dtype=getattr(jnp, model_args.dtype),
524
  decoder_dtype=getattr(jnp, model_args.dtype),
525
  )
526
+ # necessary to make Flax's generate() work
527
  model.config.decoder_start_token_id = decoder_config.decoder_start_token_id
528
 
529
  if model_args.feature_extractor_name: