ydshieh
commited on
Commit
·
8a4e070
1
Parent(s):
2eb0333
fix
Browse files
run_image_captioning_flax.py
CHANGED
@@ -849,12 +849,12 @@ def main():
|
|
849 |
if "train" not in dataset:
|
850 |
raise ValueError("--do_train requires a train dataset")
|
851 |
train_dataset = dataset["train"]
|
852 |
-
|
|
|
|
|
853 |
# (if feature extraction is performed at the beginning, the filtering is done during preprocessing not here)
|
854 |
if not run_feat_ext_at_beginning:
|
855 |
train_dataset = train_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
|
856 |
-
if data_args.max_train_samples is not None:
|
857 |
-
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
858 |
train_dataset = train_dataset.map(
|
859 |
function=function_kwarg,
|
860 |
batched=True,
|
@@ -874,12 +874,12 @@ def main():
|
|
874 |
if "validation" not in dataset:
|
875 |
raise ValueError("--do_eval requires a validation dataset")
|
876 |
eval_dataset = dataset["validation"]
|
877 |
-
|
|
|
|
|
878 |
# (if feature extraction is performed at the beginning, the filtering is done during preprocessing not here)
|
879 |
if not run_feat_ext_at_beginning:
|
880 |
eval_dataset = eval_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
|
881 |
-
if data_args.max_eval_samples is not None:
|
882 |
-
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
883 |
eval_dataset = eval_dataset.map(
|
884 |
function=function_kwarg,
|
885 |
batched=True,
|
@@ -899,12 +899,12 @@ def main():
|
|
899 |
if "test" not in dataset:
|
900 |
raise ValueError("--do_predict requires a test dataset")
|
901 |
predict_dataset = dataset["test"]
|
902 |
-
|
|
|
|
|
903 |
# (if feature extraction is performed at the beginning, the filtering is done during preprocessing not here)
|
904 |
if not run_feat_ext_at_beginning:
|
905 |
predict_dataset = predict_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
|
906 |
-
if data_args.max_predict_samples is not None:
|
907 |
-
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
908 |
predict_dataset = predict_dataset.map(
|
909 |
function=function_kwarg,
|
910 |
batched=True,
|
|
|
849 |
if "train" not in dataset:
|
850 |
raise ValueError("--do_train requires a train dataset")
|
851 |
train_dataset = dataset["train"]
|
852 |
+
if data_args.max_train_samples is not None:
|
853 |
+
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
854 |
+
# remove problematic examples
|
855 |
# (if feature extraction is performed at the beginning, the filtering is done during preprocessing not here)
|
856 |
if not run_feat_ext_at_beginning:
|
857 |
train_dataset = train_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
|
|
|
|
|
858 |
train_dataset = train_dataset.map(
|
859 |
function=function_kwarg,
|
860 |
batched=True,
|
|
|
874 |
if "validation" not in dataset:
|
875 |
raise ValueError("--do_eval requires a validation dataset")
|
876 |
eval_dataset = dataset["validation"]
|
877 |
+
if data_args.max_eval_samples is not None:
|
878 |
+
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
879 |
+
# remove problematic examples
|
880 |
# (if feature extraction is performed at the beginning, the filtering is done during preprocessing not here)
|
881 |
if not run_feat_ext_at_beginning:
|
882 |
eval_dataset = eval_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
|
|
|
|
|
883 |
eval_dataset = eval_dataset.map(
|
884 |
function=function_kwarg,
|
885 |
batched=True,
|
|
|
899 |
if "test" not in dataset:
|
900 |
raise ValueError("--do_predict requires a test dataset")
|
901 |
predict_dataset = dataset["test"]
|
902 |
+
if data_args.max_predict_samples is not None:
|
903 |
+
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
904 |
+
# remove problematic examples
|
905 |
# (if feature extraction is performed at the beginning, the filtering is done during preprocessing not here)
|
906 |
if not run_feat_ext_at_beginning:
|
907 |
predict_dataset = predict_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
|
|
|
|
|
908 |
predict_dataset = predict_dataset.map(
|
909 |
function=function_kwarg,
|
910 |
batched=True,
|