ydshieh committed · Commit 68f6bad · Parent(s): 2c5a28b

update 1

run_image_captioning_flax_reduced.py (+37 -185)
CHANGED
@@ -53,8 +53,10 @@ from transformers import (
     AutoFeatureExtractor,
     AutoTokenizer,
     FlaxAutoModelForVision2Seq,
+    FlaxVisionEncoderDecoderModel,
     HfArgumentParser,
     is_tensorboard_available,
+    VisionEncoderDecoderConfig,
 )
 from transformers.file_utils import get_full_repo_name, is_offline_mode
 
@@ -171,13 +173,6 @@ class ModelArguments:
     Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
     """
 
-    model_name_or_path: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "The model checkpoint for weights initialization."
-            "Don't set if you want to train a model from scratch."
-        },
-    )
     encoder_model_name_or_path: Optional[str] = field(
         default=None,
         metadata={
@@ -192,26 +187,6 @@ class ModelArguments:
             "Don't set if you want to train a decoder model from scratch."
         },
     )
-    model_type: Optional[str] = field(
-        default="vision-encoder-decoder",
-        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
-    )
-    encoder_model_type: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "If training from scratch, pass a vision encoder model type from the library. For example, 'vit'"
-        },
-    )
-    decoder_model_type: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "If training from scratch, pass a decoder model type from the list: "
-            + ", ".join(DECODER_MODEL_TYPES)
-        },
-    )
-    config_name: Optional[str] = field(
-        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
-    )
     encoder_config_name: Optional[str] = field(
         default=None, metadata={"help": "Pretrained encoder config name or path if not the same as encoder_model_name"}
     )
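After these removals, model construction is driven entirely by the encoder- and decoder-specific arguments. For reference, a condensed sketch of the fields that remain relevant here, reconstructed from the visible context (help strings abbreviated; the exact set in the file may differ):

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class ModelArguments:
    # Reconstructed sketch: only the fields visible in this diff's context.
    encoder_model_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "Pretrained encoder checkpoint for weights initialization."}
    )
    decoder_model_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "Pretrained decoder checkpoint for weights initialization."}
    )
    encoder_config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained encoder config name or path if not the same as encoder_model_name"}
    )
    decoder_config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained decoder config name or path if not the same as decoder_model_name"}
    )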
@@ -499,170 +474,47 @@ def main():
     encoder_cache_dir = os.path.join(model_args.cache_dir, "encoder")
     decoder_cache_dir = os.path.join(model_args.cache_dir, "decoder")
 
-    # Use explicit specified config
-    if model_args.
-    else:
-        # Use explicit specified encoder config
-        if model_args.encoder_config_name:
-            encoder_config = AutoConfig.from_pretrained(
-                model_args.encoder_config_name, cache_dir=encoder_cache_dir
-            )
-        # Use pretrained encoder model's config
-        elif model_args.encoder_model_name_or_path:
-            encoder_config = AutoConfig.from_pretrained(
-                model_args.encoder_model_name_or_path, cache_dir=encoder_cache_dir
-            )
-        # Use specified encoder model type
-        elif model_args.encoder_model_type:
-            encoder_config = AutoConfig.for_model(model_args.encoder_model_type)
-            logger.warning("You are instantiating a new config instance from scratch for the encoder.")
-        else:
-            raise ValueError(
-                "Encoder Config: if pretrained config or model location is not provided, `encoder_model_type` is required."
-            )
-
-        # Use explicit specified decoder config
-        if model_args.decoder_config_name:
-            decoder_config = AutoConfig.from_pretrained(
-                model_args.decoder_config_name, cache_dir=decoder_cache_dir
-            )
-        # Use pretrained decoder model's config
-        elif model_args.decoder_model_name_or_path:
-            decoder_config = AutoConfig.from_pretrained(
-                model_args.decoder_model_name_or_path, cache_dir=decoder_cache_dir
-            )
-        # Use specified decoder model type
-        elif model_args.decoder_model_type:
-            decoder_config = AutoConfig.for_model(model_args.decoder_model_type)
-            logger.warning("You are instantiating a new config instance from scratch for the decoder.")
-        else:
-            raise ValueError(
-                "Decoder Config: if pretrained config or model location is not provided, `decoder_model_type` is required."
-            )
-
-        logger.info("Setting `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
-        decoder_config.is_decoder = True
-        decoder_config.add_cross_attention = True
-
-        config = config_class.from_encoder_decoder_configs(encoder_config, decoder_config)
-        # For self-contained model
-        else:
-            config = config_class()
-            logger.warning("You are instantiating a new config instance from scratch.")
-
-    decoder_start_token_id = getattr(config, "decoder_start_token_id", None)
-    if not decoder_start_token_id and getattr(config, "decoder", None):
-        decoder_start_token_id = getattr(config.decoder, "decoder_start_token_id", None)
-    bos_token_id = getattr(config, "bos_token_id", None)
-    if not bos_token_id and getattr(config, "decoder", None):
-        bos_token_id = getattr(config.decoder, "bos_token_id", None)
-    eos_token_id = getattr(config, "eos_token_id", None)
-    if not eos_token_id and getattr(config, "decoder", None):
-        eos_token_id = getattr(config.decoder, "eos_token_id", None)
-    pad_token_id = getattr(config, "pad_token_id", None)
-    if not pad_token_id and getattr(config, "decoder", None):
-        pad_token_id = getattr(config.decoder, "pad_token_id", None)
-
-    if decoder_start_token_id is None:
-        decoder_start_token_id = bos_token_id
-    if pad_token_id is None:
-        pad_token_id = eos_token_id
-
-    if getattr(config, "decoder", None):
-        config.decoder.decoder_start_token_id = decoder_start_token_id
-        config.decoder.bos_token_id = bos_token_id
-        config.decoder.eos_token_id = eos_token_id
-        config.decoder.pad_token_id = pad_token_id
-
-    # Set `encoder-decoder` (top-level) specific config (not always necessary, but can avoid generate() error sometimes)
-    config.decoder_start_token_id = decoder_start_token_id
-    config.bos_token_id = bos_token_id
-    config.eos_token_id = eos_token_id
-    config.pad_token_id = pad_token_id
-
-    if model_args.model_name_or_path:
-        model = FlaxAutoModelForVision2Seq.from_pretrained(
-            model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
-        )
-    else:
-            config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
-        )
-        model_class = model.__class__
-
-        # encoder_class = FlaxAutoModel
-        # decoder_class = FlaxAutoModelForCausalLM
-        module = model.module.bind(model.params)
-        encoder_class_name = type(module.encoder).__name__.replace("Module", "Model")
-        decoder_class_name = type(module.decoder).__name__.replace("Module", "Model")
-        encoder_class = getattr(transformers, encoder_class_name, None)
-        decoder_class = getattr(transformers, decoder_class_name, None)
-
-        if hasattr(model_class, "from_encoder_decoder_pretrained"):
-
-            if model_args.encoder_model_name_or_path:
-                encoder = encoder_class.from_pretrained(
-                    model_args.encoder_model_name_or_path,
-                    config=config.encoder,
-                    seed=training_args.seed,
-                    dtype=getattr(jnp, model_args.dtype),
-                )
-            else:
-                encoder = encoder_class(
-                    config=config.encoder, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
-                )
-                logger.warning("You are instantiating a new model instance from scratch for the encoder.")
-
-            if model_args.decoder_model_name_or_path:
-                decoder = decoder_class.from_pretrained(
-                    model_args.decoder_model_name_or_path,
-                    config=config.decoder,
-                    seed=training_args.seed,
-                    dtype=getattr(jnp, model_args.dtype),
-                )
-            else:
-                decoder = decoder_class(
-                    config=config.decoder, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
-                )
-                logger.warning("You are instantiating a new model instance from scratch for the decoder.")
-
-            )
-
-            # Set `encoder-decoder` (top-level) specific config (not always necessary, but can avoid generate() error sometimes)
-            model.config.decoder_start_token_id = decoder_start_token_id
-            model.config.bos_token_id = bos_token_id
-            model.config.eos_token_id = eos_token_id
-            model.config.pad_token_id = pad_token_id
-
-        else:
-            logger.warning("You are instantiating a new model instance from scratch.")
+    # Use explicit specified encoder config
+    if model_args.encoder_config_name:
+        encoder_config = AutoConfig.from_pretrained(
+            model_args.encoder_config_name, cache_dir=encoder_cache_dir
+        )
+    # Use pretrained encoder model's config
+    elif model_args.encoder_model_name_or_path:
+        encoder_config = AutoConfig.from_pretrained(
+            model_args.encoder_model_name_or_path, cache_dir=encoder_cache_dir
+        )
+    else:
+        raise ValueError(
+            "Encoder Config: Either a pretrained config or a model location for encoder is required."
+        )
+
+    # Use explicit specified decoder config
+    if model_args.decoder_config_name:
+        decoder_config = AutoConfig.from_pretrained(
+            model_args.decoder_config_name, cache_dir=decoder_cache_dir
+        )
+    # Use pretrained decoder model's config
+    elif model_args.decoder_model_name_or_path:
+        decoder_config = AutoConfig.from_pretrained(
+            model_args.decoder_model_name_or_path, cache_dir=decoder_cache_dir
+        )
+    else:
+        raise ValueError(
+            "Decoder Config: Either a pretrained config or a model location for decoder is required."
+        )
+
+    config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
+    model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
+        encoder_pretrained_model_name_or_path=model_args.encoder_model_name_or_path,
+        decoder_pretrained_model_name_or_path=model_args.decoder_model_name_or_path,
+        encoder_config=config.encoder,
+        decoder_config=config.decoder,
+        encoder_seed=training_args.seed,
+        decoder_seed=training_args.seed,
+        encoder_dtype=getattr(jnp, model_args.dtype),
+        decoder_dtype=getattr(jnp, model_args.dtype),
+    )
 
     feature_extractor = None
     if model_args.feature_extractor_name: