ydshieh committed
Commit 8a4e070 · 1 Parent(s): 2eb0333
Files changed (1)
  1. run_image_captioning_flax.py +9 -9
run_image_captioning_flax.py CHANGED
@@ -849,12 +849,12 @@ def main():
         if "train" not in dataset:
             raise ValueError("--do_train requires a train dataset")
         train_dataset = dataset["train"]
-        # remove problematic examples
+        if data_args.max_train_samples is not None:
+            train_dataset = train_dataset.select(range(data_args.max_train_samples))
+        # remove problematic examples
         # (if feature extraction is performed at the beginning, the filtering is done during preprocessing not here)
         if not run_feat_ext_at_beginning:
             train_dataset = train_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
-        if data_args.max_train_samples is not None:
-            train_dataset = train_dataset.select(range(data_args.max_train_samples))
         train_dataset = train_dataset.map(
             function=function_kwarg,
             batched=True,
@@ -874,12 +874,12 @@
         if "validation" not in dataset:
             raise ValueError("--do_eval requires a validation dataset")
         eval_dataset = dataset["validation"]
-        # remove problematic examples
+        if data_args.max_eval_samples is not None:
+            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+        # remove problematic examples
         # (if feature extraction is performed at the beginning, the filtering is done during preprocessing not here)
         if not run_feat_ext_at_beginning:
             eval_dataset = eval_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
-        if data_args.max_eval_samples is not None:
-            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
         eval_dataset = eval_dataset.map(
             function=function_kwarg,
             batched=True,
@@ -899,12 +899,12 @@
         if "test" not in dataset:
             raise ValueError("--do_predict requires a test dataset")
         predict_dataset = dataset["test"]
-        # remove problematic examples
+        if data_args.max_predict_samples is not None:
+            predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
+        # remove problematic examples
         # (if feature extraction is performed at the beginning, the filtering is done during preprocessing not here)
         if not run_feat_ext_at_beginning:
             predict_dataset = predict_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
-        if data_args.max_predict_samples is not None:
-            predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
         predict_dataset = predict_dataset.map(
             function=function_kwarg,
             batched=True,
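For reference, a minimal sketch (not part of the commit) of what this reordering does, using the same datasets calls the script already relies on (Dataset.select and Dataset.filter); the toy split and the filter_fn below are assumptions standing in for the script's real data and filtering logic. With the new order, the max_*_samples cap is applied before the "remove problematic examples" filter, so the filter only scans the capped subset:

from datasets import Dataset

# Toy stand-in for dataset["train"]; empty strings play the role of "problematic" examples.
train_dataset = Dataset.from_dict({"text": ["a", "", "b", "c", ""]})

# Hypothetical stand-in for the script's batched filter_fn: keep non-empty examples.
def filter_fn(examples):
    return [len(t) > 0 for t in examples["text"]]

max_train_samples = 3

# New order (this commit): cap the raw split first, then drop problematic examples.
sampled = train_dataset.select(range(max_train_samples))   # 3 examples: "a", "", "b"
sampled = sampled.filter(filter_fn, batched=True)          # 2 examples remain: "a", "b"

# Old order: filter the full split first, then cap it.
filtered = train_dataset.filter(filter_fn, batched=True)   # 3 examples: "a", "b", "c"
filtered = filtered.select(range(max_train_samples))        # still 3 examples

The trade-off, as this sketch suggests: after the change the filter no longer has to process the whole split when only a few samples are requested, but the capped split may end up with fewer than max_*_samples examples once problematic ones are removed.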