ydshieh committed
Commit 68f6bad · Parent(s): 2c5a28b

Files changed (1):
  1. run_image_captioning_flax_reduced.py  +37 -185

run_image_captioning_flax_reduced.py CHANGED
@@ -53,8 +53,10 @@ from transformers import (
     AutoFeatureExtractor,
     AutoTokenizer,
     FlaxAutoModelForVision2Seq,
+    FlaxVisionEncoderDecoderModel,
     HfArgumentParser,
     is_tensorboard_available,
+    VisionEncoderDecoderConfig,
 )
 from transformers.file_utils import get_full_repo_name, is_offline_mode
 
@@ -171,13 +173,6 @@ class ModelArguments:
     Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
     """
 
-    model_name_or_path: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "The model checkpoint for weights initialization."
-            "Don't set if you want to train a model from scratch."
-        },
-    )
     encoder_model_name_or_path: Optional[str] = field(
         default=None,
         metadata={
@@ -192,26 +187,6 @@ class ModelArguments:
             "Don't set if you want to train a decoder model from scratch."
         },
     )
-    model_type: Optional[str] = field(
-        default="vision-encoder-decoder",
-        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
-    )
-    encoder_model_type: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "If training from scratch, pass a vision encoder model type from the library. For example, 'vit'"
-        },
-    )
-    decoder_model_type: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "If training from scratch, pass a decoder model type from the list: "
-            + ", ".join(DECODER_MODEL_TYPES)
-        },
-    )
-    config_name: Optional[str] = field(
-        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
-    )
     encoder_config_name: Optional[str] = field(
         default=None, metadata={"help": "Pretrained encoder config name or path if not the same as encoder_model_name"}
     )
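
Note (illustrative, not part of the commit): with `model_name_or_path`, `config_name`, and the `*_model_type` fields removed, the reduced script can only be pointed at pretrained weights through the encoder and decoder checkpoint arguments. A minimal sketch of what argument parsing now expects, using a trimmed stand-in dataclass and placeholder checkpoint names:

# Illustrative sketch only: `ReducedModelArguments` is a trimmed stand-in for the
# script's ModelArguments, and the checkpoint names below are placeholders.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class ReducedModelArguments:
    encoder_model_name_or_path: Optional[str] = field(default=None)
    decoder_model_name_or_path: Optional[str] = field(default=None)


parser = HfArgumentParser(ReducedModelArguments)
(model_args,) = parser.parse_args_into_dataclasses(
    args=[
        "--encoder_model_name_or_path", "google/vit-base-patch16-224-in21k",
        "--decoder_model_name_or_path", "gpt2",
    ]
)
print(model_args.encoder_model_name_or_path, model_args.decoder_model_name_or_path)
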
@@ -499,170 +474,47 @@ def main():
     encoder_cache_dir = os.path.join(model_args.cache_dir, "encoder")
     decoder_cache_dir = os.path.join(model_args.cache_dir, "decoder")
 
-    # Use explicit specified config
-    if model_args.config_name:
-        config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
-    # Use pretrained model's config
-    elif model_args.model_name_or_path:
-        config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
-    # Use specified `model_type` (default to `vision-encoder-decoder`)
+    # Use explicit specified encoder config
+    if model_args.encoder_config_name:
+        encoder_config = AutoConfig.from_pretrained(
+            model_args.encoder_config_name, cache_dir=encoder_cache_dir
+        )
+    # Use pretrained encoder model's config
+    elif model_args.encoder_model_name_or_path:
+        encoder_config = AutoConfig.from_pretrained(
+            model_args.encoder_model_name_or_path, cache_dir=encoder_cache_dir
+        )
     else:
+        raise ValueError(
+            "Encoder Config: Either a pretrained config or a model location for encoder is required."
+        )
 
-        if model_args.model_type not in MODEL_TYPES:
-            raise ValueError(
-                f"Unrecognized model identifier: {model_args.model_type}. Should contain one of {', '.join(MODEL_TYPES)}."
-            )
-        config_class = CONFIG_MAPPING[model_args.model_type]
-
-        # Deal with encoder-decoder models that require specifying encoder/decoder
-        if hasattr(config_class, "from_encoder_decoder_configs"):
-
-            # Use explicit specified encoder config
-            if model_args.encoder_config_name:
-                encoder_config = AutoConfig.from_pretrained(
-                    model_args.encoder_config_name, cache_dir=encoder_cache_dir
-                )
-            # Use pretrained encoder model's config
-            elif model_args.encoder_model_name_or_path:
-                encoder_config = AutoConfig.from_pretrained(
-                    model_args.encoder_model_name_or_path, cache_dir=encoder_cache_dir
-                )
-            # Use specified encoder model type
-            elif model_args.encoder_model_type:
-                encoder_config = AutoConfig.for_model(model_args.encoder_model_type)
-                logger.warning("You are instantiating a new config instance from scratch for the encoder.")
-            else:
-                raise ValueError(
-                    "Encoder Config: if pretrained config or model location is not provided, `encoder_model_type` is required."
-                )
-
-            # Use explicit specified decoder config
-            if model_args.decoder_config_name:
-                decoder_config = AutoConfig.from_pretrained(
-                    model_args.decoder_config_name, cache_dir=decoder_cache_dir
-                )
-            # Use pretrained decoder model's config
-            elif model_args.decoder_model_name_or_path:
-                decoder_config = AutoConfig.from_pretrained(
-                    model_args.decoder_model_name_or_path, cache_dir=decoder_cache_dir
-                )
-            # Use specified decoder model type
-            elif model_args.decoder_model_type:
-                decoder_config = AutoConfig.for_model(model_args.decoder_model_type)
-                logger.warning("You are instantiating a new config instance from scratch for the decoder.")
-            else:
-                raise ValueError(
-                    "Decoder Config: if pretrained config or model location is not provided, `decoder_model_type` is required."
-                )
-
-            logger.info("Setting `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
-            decoder_config.is_decoder = True
-            decoder_config.add_cross_attention = True
-
-            config = config_class.from_encoder_decoder_configs(encoder_config, decoder_config)
-        # For self-contained model
-        else:
-            config = config_class()
-            logger.warning("You are instantiating a new config instance from scratch.")
-
-    decoder_start_token_id = getattr(config, "decoder_start_token_id", None)
-    if not decoder_start_token_id and getattr(config, "decoder", None):
-        decoder_start_token_id = getattr(config.decoder, "decoder_start_token_id", None)
-    bos_token_id = getattr(config, "bos_token_id", None)
-    if not bos_token_id and getattr(config, "decoder", None):
-        bos_token_id = getattr(config.decoder, "bos_token_id", None)
-    eos_token_id = getattr(config, "eos_token_id", None)
-    if not eos_token_id and getattr(config, "decoder", None):
-        eos_token_id = getattr(config.decoder, "eos_token_id", None)
-    pad_token_id = getattr(config, "pad_token_id", None)
-    if not pad_token_id and getattr(config, "decoder", None):
-        pad_token_id = getattr(config.decoder, "pad_token_id", None)
-
-    if decoder_start_token_id is None:
-        decoder_start_token_id = bos_token_id
-    if pad_token_id is None:
-        pad_token_id = eos_token_id
-
-    if getattr(config, "decoder", None):
-        config.decoder.decoder_start_token_id = decoder_start_token_id
-        config.decoder.bos_token_id = bos_token_id
-        config.decoder.eos_token_id = eos_token_id
-        config.decoder.pad_token_id = pad_token_id
-
-    # Set `encoder-decoder` (top-level) specific config (not always necessary, but can avoid generate() error sometimes)
-    config.decoder_start_token_id = decoder_start_token_id
-    config.bos_token_id = bos_token_id
-    config.eos_token_id = eos_token_id
-    config.pad_token_id = pad_token_id
-
-    if model_args.model_name_or_path:
-        model = FlaxAutoModelForVision2Seq.from_pretrained(
-            model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+    # Use explicit specified decoder config
+    if model_args.decoder_config_name:
+        decoder_config = AutoConfig.from_pretrained(
+            model_args.decoder_config_name, cache_dir=decoder_cache_dir
+        )
+    # Use pretrained decoder model's config
+    elif model_args.decoder_model_name_or_path:
+        decoder_config = AutoConfig.from_pretrained(
+            model_args.decoder_model_name_or_path, cache_dir=decoder_cache_dir
         )
     else:
-        # model_class = FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING[config.__class__]
-        model = FlaxAutoModelForVision2Seq.from_config(
-            config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        raise ValueError(
+            "Decoder Config: Either a pretrained config or a model location for decoder is required."
        )
-        model_class = model.__class__
-
-        # encoder_class = FlaxAutoModel
-        # decoder_class = FlaxAutoModelForCausalLM
-        module = model.module.bind(model.params)
-        encoder_class_name = type(module.encoder).__name__.replace("Module", "Model")
-        decoder_class_name = type(module.decoder).__name__.replace("Module", "Model")
-        encoder_class = getattr(transformers, encoder_class_name, None)
-        decoder_class = getattr(transformers, decoder_class_name, None)
-
-        if hasattr(model_class, "from_encoder_decoder_pretrained"):
-
-            if model_args.encoder_model_name_or_path:
-                encoder = encoder_class.from_pretrained(
-                    model_args.encoder_model_name_or_path,
-                    config=config.encoder,
-                    seed=training_args.seed,
-                    dtype=getattr(jnp, model_args.dtype),
-                )
-            else:
-                encoder = encoder_class(
-                    config=config.encoder, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
-                )
-                logger.warning("You are instantiating a new model instance from scratch for the encoder.")
-
-            if model_args.decoder_model_name_or_path:
-                decoder = decoder_class.from_pretrained(
-                    model_args.decoder_model_name_or_path,
-                    config=config.decoder,
-                    seed=training_args.seed,
-                    dtype=getattr(jnp, model_args.dtype),
-                )
-            else:
-                decoder = decoder_class(
-                    config=config.decoder, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
-                )
-                logger.warning("You are instantiating a new model instance from scratch for the decoder.")
 
-            model = model_class.from_encoder_decoder_pretrained(
-                model_args.encoder_model_name_or_path,
-                model_args.decoder_model_name_or_path,
-                encoder_model=encoder,
-                decoder_model=decoder,
-                encoder_config=config.encoder,
-                decoder_config=config.decoder,
-                encoder_seed=training_args.seed,
-                decoder_seed=training_args.seed,
-                encoder_dtype=getattr(jnp, model_args.dtype),
-                decoder_dtype=getattr(jnp, model_args.dtype),
-            )
-
-            # Set `encoder-decoder` (top-level) specific config (not always necessary, but can avoid generate() error sometimes)
-            model.config.decoder_start_token_id = decoder_start_token_id
-            model.config.bos_token_id = bos_token_id
-            model.config.eos_token_id = eos_token_id
-            model.config.pad_token_id = pad_token_id
-
-        else:
-            logger.warning("You are instantiating a new model instance from scratch.")
+    config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
+    model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
+        encoder_pretrained_model_name_or_path=model_args.encoder_model_name_or_path,
+        decoder_pretrained_model_name_or_path=model_args.decoder_model_name_or_path,
+        encoder_config=config.encoder,
+        decoder_config=config.decoder,
+        encoder_seed=training_args.seed,
+        decoder_seed=training_args.seed,
+        encoder_dtype=getattr(jnp, model_args.dtype),
+        decoder_dtype=getattr(jnp, model_args.dtype),
+    )
 
     feature_extractor = None
     if model_args.feature_extractor_name:
 
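For reference, a minimal sketch of the model-construction path this commit settles on. It is not taken from the repository: the ViT encoder and GPT-2 decoder checkpoints are assumed placeholders, whereas the real script reads these names from `model_args` and also passes per-side seeds.

# Minimal sketch of the new code path, with placeholder checkpoints.
import jax.numpy as jnp
from transformers import AutoConfig, FlaxVisionEncoderDecoderModel, VisionEncoderDecoderConfig

encoder_name = "google/vit-base-patch16-224-in21k"  # assumed encoder checkpoint
decoder_name = "gpt2"  # assumed decoder checkpoint

# Build the combined config the same way the updated script does.
encoder_config = AutoConfig.from_pretrained(encoder_name)
decoder_config = AutoConfig.from_pretrained(decoder_name)
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)

# Load encoder and decoder weights; the cross-attention tying them together is newly initialized.
model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_pretrained_model_name_or_path=encoder_name,
    decoder_pretrained_model_name_or_path=decoder_name,
    encoder_config=config.encoder,
    decoder_config=config.decoder,
    encoder_dtype=jnp.float32,
    decoder_dtype=jnp.float32,
)
print(type(model).__name__)  # FlaxVisionEncoderDecoderModel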