ydshieh
commited on
Commit
·
b1c34ff
1
Parent(s):
7834fdb
update 14
Browse files
run_image_captioning_flax_reduced.py
CHANGED
|
@@ -469,20 +469,15 @@ def main():
|
|
| 469 |
|
| 470 |
# Load pretrained model and tokenizer
|
| 471 |
|
| 472 |
-
encoder_cache_dir, decoder_cache_dir = None, None
|
| 473 |
-
if model_args.cache_dir:
|
| 474 |
-
encoder_cache_dir = os.path.join(model_args.cache_dir, "encoder")
|
| 475 |
-
decoder_cache_dir = os.path.join(model_args.cache_dir, "decoder")
|
| 476 |
-
|
| 477 |
# Use explicit specified encoder config
|
| 478 |
if model_args.encoder_config_name:
|
| 479 |
encoder_config = AutoConfig.from_pretrained(
|
| 480 |
-
model_args.encoder_config_name, cache_dir=
|
| 481 |
)
|
| 482 |
# Use pretrained encoder model's config
|
| 483 |
elif model_args.encoder_model_name_or_path:
|
| 484 |
encoder_config = AutoConfig.from_pretrained(
|
| 485 |
-
model_args.encoder_model_name_or_path, cache_dir=
|
| 486 |
)
|
| 487 |
else:
|
| 488 |
raise ValueError(
|
|
@@ -492,12 +487,12 @@ def main():
|
|
| 492 |
# Use explicit specified decoder config
|
| 493 |
if model_args.decoder_config_name:
|
| 494 |
decoder_config = AutoConfig.from_pretrained(
|
| 495 |
-
model_args.decoder_config_name, cache_dir=
|
| 496 |
)
|
| 497 |
# Use pretrained decoder model's config
|
| 498 |
elif model_args.decoder_model_name_or_path:
|
| 499 |
decoder_config = AutoConfig.from_pretrained(
|
| 500 |
-
model_args.decoder_model_name_or_path, cache_dir=
|
| 501 |
)
|
| 502 |
else:
|
| 503 |
raise ValueError(
|
|
|
|
| 469 |
|
| 470 |
# Load pretrained model and tokenizer
|
| 471 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 472 |
# Use explicit specified encoder config
|
| 473 |
if model_args.encoder_config_name:
|
| 474 |
encoder_config = AutoConfig.from_pretrained(
|
| 475 |
+
model_args.encoder_config_name, cache_dir=model_args.cache_dir
|
| 476 |
)
|
| 477 |
# Use pretrained encoder model's config
|
| 478 |
elif model_args.encoder_model_name_or_path:
|
| 479 |
encoder_config = AutoConfig.from_pretrained(
|
| 480 |
+
model_args.encoder_model_name_or_path, cache_dir=model_args.cache_dir
|
| 481 |
)
|
| 482 |
else:
|
| 483 |
raise ValueError(
|
|
|
|
| 487 |
# Use explicit specified decoder config
|
| 488 |
if model_args.decoder_config_name:
|
| 489 |
decoder_config = AutoConfig.from_pretrained(
|
| 490 |
+
model_args.decoder_config_name, cache_dir=model_args.cache_dir
|
| 491 |
)
|
| 492 |
# Use pretrained decoder model's config
|
| 493 |
elif model_args.decoder_model_name_or_path:
|
| 494 |
decoder_config = AutoConfig.from_pretrained(
|
| 495 |
+
model_args.decoder_model_name_or_path, cache_dir=model_args.cache_dir
|
| 496 |
)
|
| 497 |
else:
|
| 498 |
raise ValueError(
|