ydshieh
committed on
Commit
·
2eb0333
1
Parent(s):
10a974e
fix
Browse files- run_image_captioning_flax.py +21 -15
run_image_captioning_flax.py
CHANGED
@@ -835,22 +835,23 @@ def main():
|
|
835 |
}
|
836 |
)
|
837 |
|
838 |
-
# If `block_size` is `0`, tokenization & image feature extraction is done
|
839 |
-
|
840 |
# Used in .map() below
|
841 |
-
function_kwarg = preprocess_fn if
|
842 |
# `features` is used only for the final preprocessed dataset (for the performance purpose).
|
843 |
-
features_kwarg = features if
|
844 |
# Keep `image_column` if the feature extraction is done during training
|
845 |
-
remove_columns_kwarg = [x for x in column_names if x != image_column or
|
846 |
-
processor_names = "tokenizer and feature extractor" if
|
847 |
|
848 |
if training_args.do_train:
|
849 |
if "train" not in dataset:
|
850 |
raise ValueError("--do_train requires a train dataset")
|
851 |
train_dataset = dataset["train"]
|
852 |
-
# remove problematic examples
|
853 |
-
if not
|
|
|
854 |
train_dataset = train_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
|
855 |
if data_args.max_train_samples is not None:
|
856 |
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
@@ -865,15 +866,17 @@ def main():
|
|
865 |
fn_kwargs={"max_target_length": data_args.max_target_length},
|
866 |
features=features_kwarg,
|
867 |
)
|
868 |
-
if
|
|
|
869 |
train_dataset = train_dataset.with_format("numpy")
|
870 |
|
871 |
if training_args.do_eval:
|
872 |
if "validation" not in dataset:
|
873 |
raise ValueError("--do_eval requires a validation dataset")
|
874 |
eval_dataset = dataset["validation"]
|
875 |
-
# remove problematic examples
|
876 |
-
if not
|
|
|
877 |
eval_dataset = eval_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
|
878 |
if data_args.max_eval_samples is not None:
|
879 |
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
@@ -888,15 +891,17 @@ def main():
|
|
888 |
fn_kwargs={"max_target_length": data_args.val_max_target_length},
|
889 |
features=features_kwarg,
|
890 |
)
|
891 |
-
if
|
|
|
892 |
eval_dataset = eval_dataset.with_format("numpy")
|
893 |
|
894 |
if training_args.do_predict:
|
895 |
if "test" not in dataset:
|
896 |
raise ValueError("--do_predict requires a test dataset")
|
897 |
predict_dataset = dataset["test"]
|
898 |
-
# remove problematic examples
|
899 |
-
if not
|
|
|
900 |
predict_dataset = predict_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
|
901 |
if data_args.max_predict_samples is not None:
|
902 |
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
@@ -911,7 +916,8 @@ def main():
|
|
911 |
fn_kwargs={"max_target_length": data_args.val_max_target_length},
|
912 |
features=features_kwarg,
|
913 |
)
|
914 |
-
if
|
|
|
915 |
predict_dataset = predict_dataset.with_format("numpy")
|
916 |
|
917 |
# Store some constant
|
|
|
835 |
}
|
836 |
)
|
837 |
|
838 |
+
# If `block_size` is `0`, tokenization & image feature extraction is done at the beginning
|
839 |
+
run_feat_ext_at_beginning = training_args.block_size == 0
|
840 |
# Used in .map() below
|
841 |
+
function_kwarg = preprocess_fn if run_feat_ext_at_beginning else tokenization_fn
|
842 |
# `features` is used only for the final preprocessed dataset (for the performance purpose).
|
843 |
+
features_kwarg = features if run_feat_ext_at_beginning else None
|
844 |
# Keep `image_column` if the feature extraction is done during training
|
845 |
+
remove_columns_kwarg = [x for x in column_names if x != image_column or run_feat_ext_at_beginning]
|
846 |
+
processor_names = "tokenizer and feature extractor" if run_feat_ext_at_beginning else "tokenizer"
|
847 |
|
848 |
if training_args.do_train:
|
849 |
if "train" not in dataset:
|
850 |
raise ValueError("--do_train requires a train dataset")
|
851 |
train_dataset = dataset["train"]
|
852 |
+
# remove problematic examples
|
853 |
+
# (if feature extraction is performed at the beginning, the filtering is done during preprocessing not here)
|
854 |
+
if not run_feat_ext_at_beginning:
|
855 |
train_dataset = train_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
|
856 |
if data_args.max_train_samples is not None:
|
857 |
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
|
|
866 |
fn_kwargs={"max_target_length": data_args.max_target_length},
|
867 |
features=features_kwarg,
|
868 |
)
|
869 |
+
if run_feat_ext_at_beginning:
|
870 |
+
# set format (for performance) since the dataset is ready to be used
|
871 |
train_dataset = train_dataset.with_format("numpy")
|
872 |
|
873 |
if training_args.do_eval:
|
874 |
if "validation" not in dataset:
|
875 |
raise ValueError("--do_eval requires a validation dataset")
|
876 |
eval_dataset = dataset["validation"]
|
877 |
+
# remove problematic examples
|
878 |
+
# (if feature extraction is performed at the beginning, the filtering is done during preprocessing not here)
|
879 |
+
if not run_feat_ext_at_beginning:
|
880 |
eval_dataset = eval_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
|
881 |
if data_args.max_eval_samples is not None:
|
882 |
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
|
|
891 |
fn_kwargs={"max_target_length": data_args.val_max_target_length},
|
892 |
features=features_kwarg,
|
893 |
)
|
894 |
+
if run_feat_ext_at_beginning:
|
895 |
+
# set format (for performance) since the dataset is ready to be used
|
896 |
eval_dataset = eval_dataset.with_format("numpy")
|
897 |
|
898 |
if training_args.do_predict:
|
899 |
if "test" not in dataset:
|
900 |
raise ValueError("--do_predict requires a test dataset")
|
901 |
predict_dataset = dataset["test"]
|
902 |
+
# remove problematic examples
|
903 |
+
# (if feature extraction is performed at the beginning, the filtering is done during preprocessing not here)
|
904 |
+
if not run_feat_ext_at_beginning:
|
905 |
predict_dataset = predict_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
|
906 |
if data_args.max_predict_samples is not None:
|
907 |
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
|
|
916 |
fn_kwargs={"max_target_length": data_args.val_max_target_length},
|
917 |
features=features_kwarg,
|
918 |
)
|
919 |
+
if run_feat_ext_at_beginning:
|
920 |
+
# set format (for performance) since the dataset is ready to be used
|
921 |
predict_dataset = predict_dataset.with_format("numpy")
|
922 |
|
923 |
# Store some constant
|