ydshieh
commited on
Commit
·
bcae421
1
Parent(s):
15ecbe8
update 9
Browse files
run_image_captioning_flax_reduced.py
CHANGED
|
@@ -524,6 +524,7 @@ def main():
|
|
| 524 |
decoder_dtype=getattr(jnp, model_args.dtype),
|
| 525 |
)
|
| 526 |
# necessary to make Flax's generate() work
|
|
|
|
| 527 |
model.config.decoder_start_token_id = decoder_config.decoder_start_token_id
|
| 528 |
|
| 529 |
if model_args.feature_extractor_name:
|
|
|
|
| 524 |
decoder_dtype=getattr(jnp, model_args.dtype),
|
| 525 |
)
|
| 526 |
# necessary to make Flax's generate() work
|
| 527 |
+
model.config.eos_token_id = decoder_config.eos_token_id
|
| 528 |
model.config.decoder_start_token_id = decoder_config.decoder_start_token_id
|
| 529 |
|
| 530 |
if model_args.feature_extractor_name:
|