ydshieh commited on
Commit
283180e
·
1 Parent(s): cb2b435

more general

Browse files
Files changed (1) hide show
  1. run_image_captioning_flax.py +6 -6
run_image_captioning_flax.py CHANGED
@@ -633,6 +633,8 @@ def main():
633
 
634
  return tokenizer
635
 
 
 
636
  # Preprocessing the datasets.
637
  # We need to tokenize inputs and targets.
638
  if training_args.do_train:
@@ -688,8 +690,6 @@ def main():
688
  # Setting padding="max_length" as we need fixed length inputs for jitted functions
689
  def tokenization_fn(examples, max_target_length):
690
 
691
- tokenizer = get_tokenizer()
692
-
693
  captions = []
694
  for caption in examples[caption_column]:
695
  captions.append(caption.lower() + ' ' + tokenizer.eos_token)
@@ -834,7 +834,7 @@ def main():
834
  num_test_examples = len(predict_dataset)
835
  test_steps = num_test_examples // eval_batch_size + int(num_test_examples % eval_batch_size > 0)
836
 
837
- def get_batch_iter(rng, ds, block_size, batch_size, shuffle=False, drop_last_batch=False, split=""):
838
 
839
  if not block_size:
840
  block_size = len(ds)
@@ -871,7 +871,7 @@ def main():
871
  remove_columns=[image_column],
872
  load_from_cache_file=not data_args.overwrite_cache,
873
  features=features,
874
- keep_in_memory=True,
875
  desc=f"Running feature extraction on {names[split]} dataset".replace("  ", " "),
876
  )
877
  _ds = _ds.with_format("numpy")
@@ -1095,7 +1095,7 @@ def main():
1095
  preds = []
1096
  labels = []
1097
 
1098
- batches = get_batch_iter(rng, dataset, block_size=block_size, batch_size=eval_batch_size, shuffle=False, split=name)
1099
  steps = len(dataset) // eval_batch_size + int(len(dataset) % eval_batch_size > 0)
1100
  for _ in tqdm(range(steps), desc=f"{'Predicting' if name == 'test' else 'Evaluating'}...", position=2, leave=False):
1101
  # Model forward
@@ -1197,7 +1197,7 @@ def main():
1197
 
1198
  train_metrics = []
1199
 
1200
- train_batches = get_batch_iter(input_rng, train_dataset, block_size=block_size, batch_size=train_batch_size, shuffle=True, drop_last_batch=training_args.dataloader_drop_last, split="train")
1201
 
1202
  # train
1203
  for (batch_idx, _) in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)):
 
633
 
634
  return tokenizer
635
 
636
+ tokenizer = get_tokenizer()
637
+
638
  # Preprocessing the datasets.
639
  # We need to tokenize inputs and targets.
640
  if training_args.do_train:
 
690
  # Setting padding="max_length" as we need fixed length inputs for jitted functions
691
  def tokenization_fn(examples, max_target_length):
692
 
 
 
693
  captions = []
694
  for caption in examples[caption_column]:
695
  captions.append(caption.lower() + ' ' + tokenizer.eos_token)
 
834
  num_test_examples = len(predict_dataset)
835
  test_steps = num_test_examples // eval_batch_size + int(num_test_examples % eval_batch_size > 0)
836
 
837
+ def get_batch_iter(rng, ds, block_size, batch_size, shuffle=False, drop_last_batch=False, keep_in_memory=False, split=""):
838
 
839
  if not block_size:
840
  block_size = len(ds)
 
871
  remove_columns=[image_column],
872
  load_from_cache_file=not data_args.overwrite_cache,
873
  features=features,
874
+ keep_in_memory=keep_in_memory,
875
  desc=f"Running feature extraction on {names[split]} dataset".replace("  ", " "),
876
  )
877
  _ds = _ds.with_format("numpy")
 
1095
  preds = []
1096
  labels = []
1097
 
1098
+ batches = get_batch_iter(rng, dataset, block_size=block_size, batch_size=eval_batch_size, keep_in_memory=False, shuffle=False, split=name)
1099
  steps = len(dataset) // eval_batch_size + int(len(dataset) % eval_batch_size > 0)
1100
  for _ in tqdm(range(steps), desc=f"{'Predicting' if name == 'test' else 'Evaluating'}...", position=2, leave=False):
1101
  # Model forward
 
1197
 
1198
  train_metrics = []
1199
 
1200
+ train_batches = get_batch_iter(input_rng, train_dataset, block_size=block_size, batch_size=train_batch_size, keep_in_memory=True, shuffle=True, drop_last_batch=training_args.dataloader_drop_last, split="train")
1201
 
1202
  # train
1203
  for (batch_idx, _) in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)):