ydshieh
commited on
Commit
·
15ecbe8
1
Parent(s):
f2e4555
update 8
Browse files
run_image_captioning_flax_reduced.py
CHANGED
|
@@ -503,6 +503,9 @@ def main():
|
|
| 503 |
raise ValueError(
|
| 504 |
"Decoder Config: Either a pretrained config or a model location for decoder is required."
|
| 505 |
)
|
|
|
|
|
|
|
|
|
|
| 506 |
|
| 507 |
# GPT2 only has bos/eos token but not decoder_start/pad token
|
| 508 |
if decoder_config.decoder_start_token_id is None:
|
|
@@ -520,7 +523,7 @@ def main():
|
|
| 520 |
encoder_dtype=getattr(jnp, model_args.dtype),
|
| 521 |
decoder_dtype=getattr(jnp, model_args.dtype),
|
| 522 |
)
|
| 523 |
-
#
|
| 524 |
model.config.decoder_start_token_id = decoder_config.decoder_start_token_id
|
| 525 |
|
| 526 |
if model_args.feature_extractor_name:
|
|
|
|
| 503 |
raise ValueError(
|
| 504 |
"Decoder Config: Either a pretrained config or a model location for decoder is required."
|
| 505 |
)
|
| 506 |
+
# necessary for `from_encoder_decoder_pretrained` when `decoder_config` is passed
|
| 507 |
+
decoder_config.is_decoder = True
|
| 508 |
+
decoder_config.add_cross_attention = True
|
| 509 |
|
| 510 |
# GPT2 only has bos/eos token but not decoder_start/pad token
|
| 511 |
if decoder_config.decoder_start_token_id is None:
|
|
|
|
| 523 |
encoder_dtype=getattr(jnp, model_args.dtype),
|
| 524 |
decoder_dtype=getattr(jnp, model_args.dtype),
|
| 525 |
)
|
| 526 |
+
# necessary to make Flax's generate() work
|
| 527 |
model.config.decoder_start_token_id = decoder_config.decoder_start_token_id
|
| 528 |
|
| 529 |
if model_args.feature_extractor_name:
|