ydshieh committed
Commit · eabb817 · Parent(s): a897ce1

make block_size 0 work

Files changed: run_image_captioning_flax.py (+58, -34)

run_image_captioning_flax.py CHANGED
@@ -785,6 +785,14 @@ def main():
 
         return model_inputs
 
+    def preprocess_fn(examples, max_target_length):
+
+        model_inputs = {}
+        model_inputs.update(tokenization_fn(examples, max_target_length))
+        model_inputs.update(feature_extraction_fn(model_inputs))
+
+        return model_inputs
+
     features = datasets.Features(
         {
             "pixel_values": datasets.Array3D(
@@ -801,6 +809,10 @@ def main():
         }
     )
 
+    function_kwarg = preprocess_fn if not training_args.block_size else tokenization_fn
+    features_kwarg = features if not training_args.block_size else None
+    remove_columns_kwarg = [x for x in column_names if x != image_column or not training_args.block_size]
+
    if training_args.do_train:
        if "train" not in dataset:
            raise ValueError("--do_train requires a train dataset")
@@ -810,15 +822,18 @@ def main():
         if data_args.max_train_samples is not None:
             train_dataset = train_dataset.select(range(data_args.max_train_samples))
         train_dataset = train_dataset.map(
-            function=tokenization_fn,
+            function=function_kwarg,
             batched=True,
             num_proc=data_args.preprocessing_num_workers,
             # kept image paths
-            remove_columns=[x for x in column_names if x != image_column],
+            remove_columns=remove_columns_kwarg,
             load_from_cache_file=not data_args.overwrite_cache,
             desc=f"Running tokenizer on train dataset",
             fn_kwargs={"max_target_length": data_args.max_target_length},
+            features=features_kwarg,
         )
+        if not training_args.block_size:
+            train_dataset = train_dataset.with_format("numpy")
 
     if training_args.do_eval:
         if "validation" not in dataset:
@@ -829,15 +844,18 @@ def main():
         if data_args.max_eval_samples is not None:
             eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
         eval_dataset = eval_dataset.map(
-            function=tokenization_fn,
+            function=function_kwarg,
             batched=True,
             num_proc=data_args.preprocessing_num_workers,
             # kept image paths
-            remove_columns=[x for x in column_names if x != image_column],
+            remove_columns=remove_columns_kwarg,
             load_from_cache_file=not data_args.overwrite_cache,
             desc=f"Running tokenizer on validation dataset",
             fn_kwargs={"max_target_length": data_args.val_max_target_length},
+            features=features_kwarg,
         )
+        if not training_args.block_size:
+            eval_dataset = eval_dataset.with_format("numpy")
 
     if training_args.do_predict:
         if "test" not in dataset:
@@ -848,15 +866,18 @@ def main():
         if data_args.max_predict_samples is not None:
             predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
         predict_dataset = predict_dataset.map(
-            function=tokenization_fn,
+            function=function_kwarg,
             batched=True,
             num_proc=data_args.preprocessing_num_workers,
             # kept image paths
-            remove_columns=[x for x in column_names if x != image_column],
+            remove_columns=remove_columns_kwarg,
             load_from_cache_file=not data_args.overwrite_cache,
             desc=f"Running tokenizer on prediction dataset",
             fn_kwargs={"max_target_length": data_args.val_max_target_length},
+            features=features_kwarg,
         )
+        if not training_args.block_size:
+            predict_dataset = predict_dataset.with_format("numpy")
 
     # Split the dataset into several chunks - each chunk is processed (.map) without cache to create a
     # data loader separately (in a sequential order).
@@ -894,46 +915,49 @@ def main():
         split: str = ""
     ):
 
-        if not block_size:
-            block_size = len(ds)
-
-        steps_per_block = block_size // batch_size
-        num_examples = len(ds)
-        steps = num_examples // batch_size
-        num_splits = steps // steps_per_block + int(steps % steps_per_block > 0)
-
         if shuffle:
             indices = jax.random.permutation(rng, len(ds))
             indices = np.asarray(indices)
         else:
             indices = np.arange(len(ds))
 
+        _block_size = len(ds) if not block_size else block_size
+
+        steps_per_block = _block_size // batch_size
+        num_examples = len(ds)
+        steps = num_examples // batch_size
+        num_splits = steps // steps_per_block + int(steps % steps_per_block > 0)
+
         for idx in range(num_splits):
 
-            start_idx = block_size * idx
-            end_idx = block_size * (idx + 1)
+            if not block_size:
+                _ds = ds
+            else:
+
+                start_idx = block_size * idx
+                end_idx = block_size * (idx + 1)
 
-            selected_indices = indices[start_idx:end_idx]
+                selected_indices = indices[start_idx:end_idx]
 
-            _ds = ds.select(selected_indices)
+                _ds = ds.select(selected_indices)
 
-            names = {
-                "train": "train",
-                "valid": "validation",
-                "test": "prediction",
-            }
+                names = {
+                    "train": "train",
+                    "valid": "validation",
+                    "test": "prediction",
+                }
 
-            _ds = _ds.map(
-                feature_extraction_fn,
-                batched=True,
-                num_proc=data_args.preprocessing_num_workers,
-                remove_columns=[image_column],
-                load_from_cache_file=not data_args.overwrite_cache,
-                features=features,
-                keep_in_memory=keep_in_memory,
-                desc=f"Running feature extraction on {names[split]} dataset".replace("  ", " "),
-            )
-            _ds = _ds.with_format("numpy")
+                _ds = _ds.map(
+                    feature_extraction_fn,
+                    batched=True,
+                    num_proc=data_args.preprocessing_num_workers,
+                    remove_columns=[image_column],
+                    load_from_cache_file=not data_args.overwrite_cache,
+                    features=features,
+                    keep_in_memory=keep_in_memory,
+                    desc=f"Running feature extraction on {names[split]} dataset".replace("  ", " "),
+                )
+                _ds = _ds.with_format("numpy")
 
         # No need to shuffle here
         loader = data_loader(rng, _ds, batch_size=batch_size, shuffle=False)
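Read together, the two sides of the last hunk show the actual fix. The old code rebound its `block_size` argument (`block_size = len(ds)`), so a zero value lost its meaning before the loop even started. The new code keeps `block_size` falsy and derives a local `_block_size` for the step arithmetic, which lets the loop treat the whole dataset as a single block. A minimal, self-contained sketch of that arithmetic (pure Python; `split_plan` and the sample sizes are hypothetical, not from the script):

def split_plan(num_examples: int, batch_size: int, block_size: int = 0) -> int:
    # The fix: derive a local _block_size instead of overwriting the
    # `block_size` parameter, so `block_size` keeps its falsy value for
    # the `if not block_size` short-circuit inside the loop.
    _block_size = num_examples if not block_size else block_size
    steps_per_block = _block_size // batch_size
    steps = num_examples // batch_size
    return steps // steps_per_block + int(steps % steps_per_block > 0)

print(split_plan(1000, batch_size=8, block_size=256))  # 4 blocks
print(split_plan(1000, batch_size=8, block_size=0))    # 1 block: the whole dataset

With `block_size=0` the dataset becomes a single block, which pairs with the eager `preprocess_fn` path above: features are computed once, in a cached `.map()`, instead of once per block on every pass.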
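On the preprocessing side, the commit dispatches between two strategies: with `block_size == 0`, tokenization and image feature extraction run together in one cached `.map()` via the new `preprocess_fn`, and the dataset is converted to NumPy right away; otherwise only tokenization runs up front and feature extraction is deferred to the per-block `.map()` inside the loader. A toy sketch of that dispatch (the data and function bodies are illustrative stand-ins; only the control flow mirrors the script):

from datasets import Dataset

image_column = "image_path"
block_size = 0  # the newly supported case

def tokenization_fn(examples, max_target_length=16):
    # Stand-in for the script's caption tokenization.
    return {"labels": [[min(len(c), max_target_length)] for c in examples["caption"]]}

def feature_extraction_fn(examples):
    # Stand-in for the script's image processing.
    return {"pixel_values": [[[0.0]]] * len(examples["labels"])}

def preprocess_fn(examples, max_target_length):
    # Mirrors the helper added by this commit: tokenize and extract in one pass.
    model_inputs = {}
    model_inputs.update(tokenization_fn(examples, max_target_length))
    model_inputs.update(feature_extraction_fn(model_inputs))
    return model_inputs

ds = Dataset.from_dict({"image_path": ["a.jpg", "b.jpg"], "caption": ["a cat", "a dog"]})
column_names = ds.column_names

# block_size == 0: do everything now and go numpy; otherwise tokenize only
# and keep the image paths for the block-wise feature extraction later.
function_kwarg = preprocess_fn if not block_size else tokenization_fn
remove_columns_kwarg = [x for x in column_names if x != image_column or not block_size]

ds = ds.map(
    function_kwarg,
    batched=True,
    remove_columns=remove_columns_kwarg,
    fn_kwargs={"max_target_length": 16},
)
if not block_size:
    ds = ds.with_format("numpy")

print(ds.column_names)  # ['labels', 'pixel_values']

The `features=features_kwarg` argument in the real script appears to serve the eager path the same way: since `preprocess_fn` emits pixel arrays, the output schema is declared up front rather than inferred.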
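Inside the loader itself, the same falsy flag short-circuits block selection: when `block_size == 0` the loop body reuses the already-preprocessed dataset untouched, and the per-block `select()`/`map()`/`with_format()` chain only runs in the `else` branch. A simplified generator showing the shape of that control flow (`blockwise` is a hypothetical stand-in that splits whole examples rather than batch-sized steps):

import numpy as np

def blockwise(indices: np.ndarray, block_size: int):
    n = len(indices)
    _block_size = n if not block_size else block_size
    num_splits = n // _block_size + int(n % _block_size > 0)
    for idx in range(num_splits):
        if not block_size:
            # block_size == 0: the whole dataset is one block; nothing to select.
            yield indices
        else:
            yield indices[block_size * idx : block_size * (idx + 1)]

print([len(b) for b in blockwise(np.arange(10), block_size=4)])  # [4, 4, 2]
print([len(b) for b in blockwise(np.arange(10), block_size=0)])  # [10]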