ydshieh committed on
Commit
a5be38e
·
1 Parent(s): fd1b4a2

improve code

Browse files
Files changed (1) hide show
  1. run_image_captioning_flax.py +13 -9
run_image_captioning_flax.py CHANGED
@@ -48,6 +48,7 @@ from huggingface_hub import Repository
48
  from transformers import (
49
  CONFIG_MAPPING,
50
  FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
 
51
  AutoConfig,
52
  AutoFeatureExtractor,
53
  AutoTokenizer,
@@ -73,6 +74,7 @@ except (LookupError, OSError):
73
 
74
  MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING.keys())
75
  MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
 
76
 
77
 
78
  # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
@@ -196,15 +198,15 @@ class ModelArguments:
196
  )
197
  model_type: Optional[str] = field(
198
  default='vision-encoder-decoder',
199
- metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
200
  )
201
  encoder_model_type: Optional[str] = field(
202
  default=None,
203
- metadata={"help": "If training from scratch, pass a encoder model type from the list: " + ", ".join(MODEL_TYPES)},
204
  )
205
  decoder_model_type: Optional[str] = field(
206
  default=None,
207
- metadata={"help": "If training from scratch, pass a decoder model type from the list: " + ", ".join(MODEL_TYPES)},
208
  )
209
  config_name: Optional[str] = field(
210
  default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
@@ -570,7 +572,7 @@ def main():
570
  config.decoder.eos_token_id = eos_token_id
571
  config.decoder.pad_token_id = pad_token_id
572
 
573
- # Set `encoder-decoder` (top-level) specific config
574
  config.decoder_start_token_id = decoder_start_token_id
575
  config.bos_token_id = bos_token_id
576
  config.eos_token_id = eos_token_id
@@ -630,7 +632,7 @@ def main():
630
  decoder_dtype=getattr(jnp, model_args.dtype),
631
  )
632
 
633
- # Set `encoder-decoder` (top-level) specific config
634
  model.config.decoder_start_token_id = decoder_start_token_id
635
  model.config.bos_token_id = bos_token_id
636
  model.config.eos_token_id = eos_token_id
@@ -729,6 +731,7 @@ def main():
729
  shift_tokens_right_fn = getattr(model_module, "shift_tokens_right", shift_tokens_right)
730
 
731
  def filter_fn(examples):
 
732
 
733
  bools = []
734
  for image_file in examples[image_column]:
@@ -1163,7 +1166,8 @@ def main():
1163
  if not os.path.isdir(os.path.join(training_args.output_dir)):
1164
  os.makedirs(os.path.join(training_args.output_dir), exist_ok=True)
1165
 
1166
- def save_ckpt(ckpt_dir: str, commit_msg: str =""):
 
1167
 
1168
  # save checkpoint after each epoch and push checkpoint to the hub
1169
  if jax.process_index() == 0:
@@ -1259,7 +1263,7 @@ def main():
1259
  json.dump(metrics, f, indent=4, sort_keys=True)
1260
 
1261
  # Update report
1262
- with open(os.path.join(training_args.output_dir, 'report.txt'), 'a', encoding='UTF-8') as fp:
1263
  fp.write(desc + '\n')
1264
 
1265
  # Save metrics (only for the evaluation/prediction being done along with training)
@@ -1325,7 +1329,7 @@ def main():
1325
 
1326
  logger.info(desc)
1327
 
1328
- with open(os.path.join(training_args.output_dir, 'report.txt'), 'a', encoding='UTF-8') as fp:
1329
  fp.write(desc + '\n')
1330
 
1331
  # Save metrics
@@ -1347,7 +1351,7 @@ def main():
1347
 
1348
  logger.info(desc)
1349
 
1350
- with open(os.path.join(training_args.output_dir, 'report.txt'), 'a', encoding='UTF-8') as fp:
1351
  fp.write(desc + '\n')
1352
 
1353
  # Save metrics
 
48
  from transformers import (
49
  CONFIG_MAPPING,
50
  FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
51
+ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
52
  AutoConfig,
53
  AutoFeatureExtractor,
54
  AutoTokenizer,
 
74
 
75
  MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING.keys())
76
  MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
77
+ DECODER_MODEL_TYPES = tuple(conf.model_type for conf in list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys()))
78
 
79
 
80
  # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
 
198
  )
199
  model_type: Optional[str] = field(
200
  default='vision-encoder-decoder',
201
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}
202
  )
203
  encoder_model_type: Optional[str] = field(
204
  default=None,
205
+ metadata={"help": "If training from scratch, pass a encoder model type from the library"}
206
  )
207
  decoder_model_type: Optional[str] = field(
208
  default=None,
209
+ metadata={"help": "If training from scratch, pass a decoder model type from the list: " + ", ".join(DECODER_MODEL_TYPES)}
210
  )
211
  config_name: Optional[str] = field(
212
  default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
 
572
  config.decoder.eos_token_id = eos_token_id
573
  config.decoder.pad_token_id = pad_token_id
574
 
575
+ # Set `encoder-decoder` (top-level) specific config (not always necessary, but can avoid generate() error sometimes)
576
  config.decoder_start_token_id = decoder_start_token_id
577
  config.bos_token_id = bos_token_id
578
  config.eos_token_id = eos_token_id
 
632
  decoder_dtype=getattr(jnp, model_args.dtype),
633
  )
634
 
635
+ # Set `encoder-decoder` (top-level) specific config (not always necessary, but can avoid generate() error sometimes)
636
  model.config.decoder_start_token_id = decoder_start_token_id
637
  model.config.bos_token_id = bos_token_id
638
  model.config.eos_token_id = eos_token_id
 
731
  shift_tokens_right_fn = getattr(model_module, "shift_tokens_right", shift_tokens_right)
732
 
733
  def filter_fn(examples):
734
+ """remove problematic images"""
735
 
736
  bools = []
737
  for image_file in examples[image_column]:
 
1166
  if not os.path.isdir(os.path.join(training_args.output_dir)):
1167
  os.makedirs(os.path.join(training_args.output_dir), exist_ok=True)
1168
 
1169
+ def save_ckpt(ckpt_dir: str, commit_msg: str = ""):
1170
+ """save checkpoints and push to Hugging Face Hub if specified"""
1171
 
1172
  # save checkpoint after each epoch and push checkpoint to the hub
1173
  if jax.process_index() == 0:
 
1263
  json.dump(metrics, f, indent=4, sort_keys=True)
1264
 
1265
  # Update report
1266
+ with open(os.path.join(training_args.output_dir, 'log'), 'a', encoding='UTF-8') as fp:
1267
  fp.write(desc + '\n')
1268
 
1269
  # Save metrics (only for the evaluation/prediction being done along with training)
 
1329
 
1330
  logger.info(desc)
1331
 
1332
+ with open(os.path.join(training_args.output_dir, 'log'), 'a', encoding='UTF-8') as fp:
1333
  fp.write(desc + '\n')
1334
 
1335
  # Save metrics
 
1351
 
1352
  logger.info(desc)
1353
 
1354
+ with open(os.path.join(training_args.output_dir, 'log'), 'a', encoding='UTF-8') as fp:
1355
  fp.write(desc + '\n')
1356
 
1357
  # Save metrics