ydshieh committed
Commit 2eb0333 · 1 Parent(s): 10a974e
Files changed (1)
  1. run_image_captioning_flax.py +21 -15
run_image_captioning_flax.py CHANGED
@@ -835,22 +835,23 @@ def main():
         }
     )
 
-    # If `block_size` is `0`, tokenization & image feature extraction is done before training
-    run_feat_ext_before_training = training_args.block_size == 0
+    # If `block_size` is `0`, tokenization & image feature extraction is done at the beginning
+    run_feat_ext_at_beginning = training_args.block_size == 0
     # Used in .map() below
-    function_kwarg = preprocess_fn if run_feat_ext_before_training else tokenization_fn
+    function_kwarg = preprocess_fn if run_feat_ext_at_beginning else tokenization_fn
     # `features` is used only for the final preprocessed dataset (for the performance purpose).
-    features_kwarg = features if run_feat_ext_before_training else None
+    features_kwarg = features if run_feat_ext_at_beginning else None
     # Keep `image_column` if the feature extraction is done during training
-    remove_columns_kwarg = [x for x in column_names if x != image_column or run_feat_ext_before_training]
-    processor_names = "tokenizer and feature extractor" if run_feat_ext_before_training else "tokenizer"
+    remove_columns_kwarg = [x for x in column_names if x != image_column or run_feat_ext_at_beginning]
+    processor_names = "tokenizer and feature extractor" if run_feat_ext_at_beginning else "tokenizer"
 
     if training_args.do_train:
         if "train" not in dataset:
             raise ValueError("--do_train requires a train dataset")
         train_dataset = dataset["train"]
-        # remove problematic examples
-        if not run_feat_ext_before_training:
+        # remove problematic examples
+        # (if feature extraction is performed at the beginning, the filtering is done during preprocessing not here)
+        if not run_feat_ext_at_beginning:
             train_dataset = train_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
         if data_args.max_train_samples is not None:
             train_dataset = train_dataset.select(range(data_args.max_train_samples))
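Note on the `remove_columns_kwarg` comprehension: a column `x` ends up in the remove list when `x != image_column or run_feat_ext_at_beginning` holds, so `image_column` is dropped only when feature extraction already ran at the beginning; otherwise the raw images must survive for on-the-fly extraction during training. A minimal, self-contained illustration (the column names below are hypothetical, not from the script):

# Illustrative stand-ins; only the boolean logic mirrors the diff above.
column_names = ["image_path", "caption", "id"]
image_column = "image_path"

for run_feat_ext_at_beginning in (True, False):
    remove_columns = [x for x in column_names if x != image_column or run_feat_ext_at_beginning]
    print(run_feat_ext_at_beginning, remove_columns)
# True  -> ['image_path', 'caption', 'id']  (features already extracted; images no longer needed)
# False -> ['caption', 'id']                (keep images for on-the-fly extraction)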
@@ -865,15 +866,17 @@ def main():
             fn_kwargs={"max_target_length": data_args.max_target_length},
             features=features_kwarg,
         )
-        if run_feat_ext_before_training:
+        if run_feat_ext_at_beginning:
+            # set format (for performance) since the dataset is ready to be used
             train_dataset = train_dataset.with_format("numpy")
 
     if training_args.do_eval:
         if "validation" not in dataset:
             raise ValueError("--do_eval requires a validation dataset")
         eval_dataset = dataset["validation"]
-        # remove problematic examples
-        if not run_feat_ext_before_training:
+        # remove problematic examples
+        # (if feature extraction is performed at the beginning, the filtering is done during preprocessing not here)
+        if not run_feat_ext_at_beginning:
             eval_dataset = eval_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
         if data_args.max_eval_samples is not None:
             eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
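The `filter_fn` pass drops examples the processors would choke on (e.g. unreadable images). A minimal sketch of the same batched-filter pattern with `datasets`; the `filter_fn` body below is an illustrative stand-in, not the script's actual implementation:

from datasets import Dataset
from PIL import Image

ds = Dataset.from_dict({"image_path": ["a.jpg", "b.jpg"], "caption": ["x", "y"]})

def filter_fn(examples):
    # With batched=True, the function receives a batch of examples and
    # must return one bool per example.
    keep = []
    for path in examples["image_path"]:
        try:
            with Image.open(path) as img:
                img.verify()
            keep.append(True)
        except Exception:
            keep.append(False)
    return keep

ds = ds.filter(filter_fn, batched=True, num_proc=2)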
@@ -888,15 +891,17 @@ def main():
             fn_kwargs={"max_target_length": data_args.val_max_target_length},
             features=features_kwarg,
         )
-        if run_feat_ext_before_training:
+        if run_feat_ext_at_beginning:
+            # set format (for performance) since the dataset is ready to be used
             eval_dataset = eval_dataset.with_format("numpy")
 
     if training_args.do_predict:
         if "test" not in dataset:
             raise ValueError("--do_predict requires a test dataset")
         predict_dataset = dataset["test"]
-        # remove problematic examples
-        if not run_feat_ext_before_training:
+        # remove problematic examples
+        # (if feature extraction is performed at the beginning, the filtering is done during preprocessing not here)
+        if not run_feat_ext_at_beginning:
             predict_dataset = predict_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
         if data_args.max_predict_samples is not None:
             predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
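`with_format("numpy")` is applied only when the dataset is already fully preprocessed, so indexing then returns ready-made NumPy arrays instead of Python lists, skipping per-batch conversion in the training loop. A small sketch under assumed toy data:

import numpy as np
from datasets import Dataset

ds = Dataset.from_dict(
    {"pixel_values": [np.zeros((3, 2, 2)).tolist() for _ in range(4)], "labels": [[1], [2], [3], [4]]}
)
ds = ds.with_format("numpy")
batch = ds[:2]
# Slicing now yields arrays directly: <class 'numpy.ndarray'> (2, 3, 2, 2)
print(type(batch["pixel_values"]), batch["pixel_values"].shape)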
@@ -911,7 +916,8 @@ def main():
             fn_kwargs={"max_target_length": data_args.val_max_target_length},
             features=features_kwarg,
         )
-        if run_feat_ext_before_training:
+        if run_feat_ext_at_beginning:
+            # set format (for performance) since the dataset is ready to be used
             predict_dataset = predict_dataset.with_format("numpy")
 
     # Store some constant
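On `features_kwarg`: in the fully-preprocessed path the script passes an explicit `Features` schema to `.map()`, sparing `datasets` a type-inference pass over large array columns (the "for the performance purpose" comment above). A hedged sketch with hypothetical column names and a stand-in for the real tokenization:

from datasets import Dataset, Features, Sequence, Value

ds = Dataset.from_dict({"caption": ["a cat", "a dog"]})

def fake_tokenize(examples):
    # Stand-in for the script's tokenization_fn: caption lengths as fake token ids.
    return {"labels": [[len(c)] for c in examples["caption"]]}

# Declaring the output schema up front avoids type inference on the mapped columns.
features = Features({"labels": Sequence(Value("int32"))})
ds = ds.map(fake_tokenize, batched=True, remove_columns=["caption"], features=features)
print(ds.features)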
 