ydshieh committed
Commit 89bf8e9 · 1 Parent(s): 6b61700
Files changed (1)
  1. run_image_captioning_flax.py +53 -42
run_image_captioning_flax.py CHANGED
@@ -112,7 +112,15 @@ class TrainingArguments:
     per_device_eval_batch_size: int = field(
         default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
     )
-    block_size: int = field(default=None, metadata={"help": "???"})
+    _block_size_doc = \
+    """
+    Split a dataset into chunks of size `block_size`. On each block, images are transformed by the feature extractor
+    and are kept in memory, and batches of size `batch_size` are yielded before processing the next block.
+    """
+    block_size: int = field(
+        default=64,
+        metadata={"help": _block_size_doc}
+    )
     learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
     weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
     adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
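The reworked `block_size` field now has a real default and a real help string: the dataset is consumed in blocks of `block_size` examples, each block is run through the feature extractor and held in memory, and batches are yielded from it before the next block is processed. Because the field lives in the script's own `TrainingArguments` dataclass, it also becomes a command-line flag once the dataclass is parsed; a minimal sketch, assuming the usual `HfArgumentParser` setup (the parser call itself is not part of this diff):

    from dataclasses import dataclass, field
    from transformers import HfArgumentParser

    @dataclass
    class TrainingArguments:
        # Number of examples feature-extracted and held in memory per block.
        block_size: int = field(default=64, metadata={"help": "Examples per block."})

    # `--block_size 256` overrides the default of 64.
    (training_args,) = HfArgumentParser(TrainingArguments).parse_args_into_dataclasses(["--block_size", "256"])
    print(training_args.block_size)  # 256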
@@ -351,7 +359,7 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
     Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
     Shuffle batches if `shuffle` is `True`.
     """
-    steps_per_epoch = len(dataset) // batch_size
+    steps = len(dataset) // batch_size
 
     if shuffle:
         batch_idx = jax.random.permutation(rng, len(dataset))
@@ -359,7 +367,7 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
     else:
         batch_idx = np.arange(len(dataset))
 
-    for idx in range(steps_per_epoch):
+    for idx in range(steps):
 
         start_idx = batch_size * idx
         end_idx = batch_size * (idx + 1)
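Only the step-count variable and the loop header are renamed in `data_loader`; the rest of the generator is unchanged and not shown here. For context, a minimal sketch of how such a loader typically completes the loop. The batch selection and the `flax.training.common_utils.shard` call below are assumptions about the unchanged remainder of the function, not something this diff shows:

    import jax
    import numpy as np
    from flax.training.common_utils import shard

    def data_loader(rng, dataset, batch_size, shuffle=False):
        # One batch per full `batch_size` slice; a trailing partial batch is dropped.
        steps = len(dataset) // batch_size
        if shuffle:
            batch_idx = np.asarray(jax.random.permutation(rng, len(dataset)))
        else:
            batch_idx = np.arange(len(dataset))
        for idx in range(steps):
            batch = dataset[batch_idx[batch_size * idx : batch_size * (idx + 1)]]
            # Shard each array over the local devices so it can feed pmap-ed train/eval steps.
            yield {k: shard(np.array(v)) for k, v in batch.items()}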
@@ -661,37 +669,31 @@ def main():
             "You can do it from another script, save it, and load it from here, using --feature_extractor_name."
         )
 
-    def get_tokenizer():
-
-        tokenizer = None
-        if model_args.tokenizer_name:
+    tokenizer = None
+    if model_args.tokenizer_name:
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+        )
+    elif model_args.model_name_or_path:
+        try:
             tokenizer = AutoTokenizer.from_pretrained(
-                model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+                model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
             )
-        elif model_args.model_name_or_path:
-            try:
-                tokenizer = AutoTokenizer.from_pretrained(
-                    model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
-                )
-            except ValueError as e:
-                logger.warning(e)
-
-        # Check decoder
-        if not tokenizer:
-            if model_args.decoder_model_name_or_path:
-                tokenizer = AutoTokenizer.from_pretrained(
-                    model_args.decoder_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
-                )
-            else:
-                raise ValueError(
-                    "You are instantiating a new tokenizer from scratch. This is not supported by this script."
-                    "You can do it from another script, save it, and load it from here, using --tokenizer_name."
-                )
-        tokenizer.pad_token = tokenizer.convert_ids_to_tokens(config.pad_token_id)
-
-        return tokenizer
+        except ValueError as e:
+            logger.warning(e)
 
-    tokenizer = get_tokenizer()
+    # Check decoder
+    if not tokenizer:
+        if model_args.decoder_model_name_or_path:
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_args.decoder_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+            )
+        else:
+            raise ValueError(
+                "You are instantiating a new tokenizer from scratch. This is not supported by this script."
+                "You can do it from another script, save it, and load it from here, using --tokenizer_name."
+            )
+    tokenizer.pad_token = tokenizer.convert_ids_to_tokens(config.pad_token_id)
 
     # Preprocessing the datasets.
     # We need to tokenize inputs and targets.
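The inlined block above resolves the tokenizer in a fixed order: an explicit `--tokenizer_name`, then the checkpoint in `--model_name_or_path`, then the decoder checkpoint in `--decoder_model_name_or_path`, and it errors out if none of them yields a tokenizer; finally the pad token is taken from `config.pad_token_id`. A condensed, standalone restatement of that fallback, slightly more lenient than the script (which only guards the `--model_name_or_path` attempt with try/except); the helper name and the `gpt2` example are illustrative only:

    from transformers import AutoTokenizer

    def resolve_tokenizer(tokenizer_name=None, model_name_or_path=None, decoder_model_name_or_path=None):
        # Try each candidate in order; skip candidates that do not provide a usable tokenizer.
        for name in (tokenizer_name, model_name_or_path, decoder_model_name_or_path):
            if not name:
                continue
            try:
                return AutoTokenizer.from_pretrained(name)
            except ValueError:
                continue
        raise ValueError("No tokenizer source given; training a tokenizer from scratch is not supported here.")

    tokenizer = resolve_tokenizer(decoder_model_name_or_path="gpt2")  # falls through to the decoder checkpoint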
@@ -864,8 +866,6 @@ def main():
         fn_kwargs={"max_target_length": data_args.val_max_target_length},
     )
 
-    tokenizer = get_tokenizer()
-
     # Split the dataset into several chunks - each chunk is processed (.map) without cache to create a
     # data loader separately (in a sequential order).
     block_size = training_args.block_size
@@ -892,18 +892,26 @@ def main():
     num_test_examples = len(predict_dataset)
     test_steps = num_test_examples // eval_batch_size
 
-    def get_batch_iter(rng, ds, block_size, batch_size, shuffle=False, keep_in_memory=False, split=""):
+    def get_batch_iter(
+        rng: jax.random.PRNGKey,
+        ds: Dataset,
+        block_size: int,
+        batch_size: int,
+        shuffle: bool = False,
+        keep_in_memory: bool = False,
+        split: str = ""
+    ):
 
         if not block_size:
             block_size = len(ds)
 
-        steps_per_split = block_size // batch_size
+        steps_per_block = block_size // batch_size
         num_examples = len(ds)
         steps = num_examples // batch_size
-        num_splits = steps // steps_per_split + int(steps % steps_per_split > 0)
+        num_splits = steps // steps_per_block + int(steps % steps_per_block > 0)
 
         if shuffle:
-            indices = jax.random.permutation(rng, len(train_dataset))
+            indices = jax.random.permutation(rng, len(ds))
             indices = np.asarray(indices)
         else:
             indices = np.arange(len(ds))
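Besides the type hints, this hunk renames `steps_per_split` to `steps_per_block` and fixes the shuffle path to permute `len(ds)` rather than `len(train_dataset)`, which was wrong whenever the iterator was built for the eval or test set. A quick worked example of the block arithmetic with made-up sizes:

    num_examples = 1000   # illustrative len(ds)
    batch_size = 8
    block_size = 64

    steps_per_block = block_size // batch_size                                # 8 batches drawn from each block
    steps = num_examples // batch_size                                        # 125 batches over the whole dataset
    num_splits = steps // steps_per_block + int(steps % steps_per_block > 0)  # 15 full blocks + 1 partial = 16
    print(steps_per_block, steps, num_splits)                                 # 8 125 16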
@@ -1131,7 +1139,7 @@ def main():
     if not os.path.isdir(os.path.join(training_args.output_dir)):
        os.makedirs(os.path.join(training_args.output_dir), exist_ok=True)
 
-    def save_results(epoch, step):
+    def save_results(epoch: int, step: int):
 
         # save checkpoint after each epoch and push checkpoint to the hub
         if jax.process_index() == 0:
@@ -1143,7 +1151,7 @@ def main():
            commit_msg = f"Saving weights and logs of epoch {epoch + 1}- step {step}"
            repo.push_to_hub(commit_message=commit_msg, blocking=False)
 
-    def evaluation_loop(rng, dataset, split):
+    def evaluation_loop(rng: jax.random.PRNGKey, dataset: Dataset, split: str):
 
        if split not in ["valid", "test"]:
            raise ValueError(f"`name` must be either \"valid\" or \"test\". Got {split} instead.")
@@ -1239,10 +1247,10 @@ def main():
            with open(os.path.join(training_args.output_dir, ckpt_dir, f'generation_{split}.json'), 'w', encoding='UTF-8') as fp:
                json.dump(generations, fp, ensure_ascii=False, indent=4)
 
-    def evaluate(rng, dataset):
+    def evaluate(rng: jax.random.PRNGKey, dataset: Dataset):
        evaluation_loop(rng, dataset, split='valid')
 
-    def predict(rng, dataset):
+    def predict(rng: jax.random.PRNGKey, dataset: Dataset):
        evaluation_loop(rng, dataset, split='test')
 
    input_rng = None
@@ -1292,7 +1300,8 @@ def main():
            if has_tensorboard and jax.process_index() == 0:
                write_metric(summary_writer, "train", train_metrics, cur_step, train_time=train_time)
 
-            # ======================== Evaluating ==============================
+            # ======================== Evaluating (inside epoch) ==============================
+
            if training_args.do_eval and (training_args.eval_steps is not None and training_args.eval_steps > 0) and cur_step % training_args.eval_steps == 0:
                evaluate(input_rng, eval_dataset)
                save_results(epoch, cur_step)
@@ -1311,6 +1320,8 @@ def main():
        if has_tensorboard and jax.process_index() == 0:
            write_metric(summary_writer, "train", train_metrics, cur_step, train_time=train_time)
 
+        # ======================== Evaluating (after each epoch) ==============================
+
        if training_args.do_eval and (training_args.eval_steps is None or training_args.eval_steps <= 0):
            evaluate(input_rng, eval_dataset)
            save_results(epoch, cur_step)
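The two new banner comments label the two evaluation paths, both gated on `--do_eval`: when `eval_steps` is a positive number, evaluation and checkpointing run inside the epoch every `eval_steps` steps; when it is unset or non-positive, they run once at the end of each epoch instead. A standalone restatement of that schedule (the helper and the numbers are illustrative, not part of the script):

    def should_evaluate(cur_step, eval_steps, end_of_epoch):
        # In-epoch path: eval_steps is a positive number and the current step is a multiple of it.
        if eval_steps is not None and eval_steps > 0:
            return cur_step % eval_steps == 0
        # Otherwise, fall back to evaluating once per epoch.
        return end_of_epoch

    print(should_evaluate(500, 500, end_of_epoch=False))   # True: in-epoch evaluation fires
    print(should_evaluate(730, None, end_of_epoch=True))   # True: end-of-epoch fallback
    print(should_evaluate(730, 500, end_of_epoch=True))    # False: this configuration only evaluates mid-epoch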
 