ydshieh committed
Commit: eabb817
Parent(s): a897ce1

make block_size 0 work

Files changed: run_image_captioning_flax.py (+58 -34)
run_image_captioning_flax.py
CHANGED
@@ -785,6 +785,14 @@ def main():
 
         return model_inputs
 
+    def preprocess_fn(examples, max_target_length):
+
+        model_inputs = {}
+        model_inputs.update(tokenization_fn(examples, max_target_length))
+        model_inputs.update(feature_extraction_fn(model_inputs))
+
+        return model_inputs
+
     features = datasets.Features(
         {
             "pixel_values": datasets.Array3D(
@@ -801,6 +809,10 @@ def main():
         }
     )
 
+    function_kwarg = preprocess_fn if not training_args.block_size else tokenization_fn
+    features_kwarg = features if not training_args.block_size else None
+    remove_columns_kwarg = [x for x in column_names if x != image_column or not training_args.block_size]
+
     if training_args.do_train:
         if "train" not in dataset:
             raise ValueError("--do_train requires a train dataset")
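The three switches added above decide, once, how every split is mapped. With block_size left at 0, the combined preprocess_fn (tokenization plus image feature extraction) runs in a single up-front .map call, the output schema is pinned via features, and every original column, including the image path column, is dropped. With a positive block_size, only tokenization runs up front, the image path column is kept, and feature extraction is deferred to the block-wise loader. A minimal sketch of that selection logic, using a hypothetical helper and stand-in argument values rather than the script's objects:

# Illustrative sketch only, not code from the commit: how function_kwarg,
# features_kwarg and remove_columns_kwarg resolve for block_size == 0 vs > 0.
def resolve_map_kwargs(block_size, column_names, image_column, preprocess_fn, tokenization_fn, features):
    # block_size == 0: preprocess everything up front, pin the schema, drop all original columns.
    # block_size > 0: only tokenize up front, keep the image path column for later feature extraction.
    function_kwarg = preprocess_fn if not block_size else tokenization_fn
    features_kwarg = features if not block_size else None
    remove_columns_kwarg = [x for x in column_names if x != image_column or not block_size]
    return function_kwarg, features_kwarg, remove_columns_kwarg


cols = ["image_path", "caption"]
assert resolve_map_kwargs(0, cols, "image_path", "pre", "tok", "schema") == ("pre", "schema", cols)
assert resolve_map_kwargs(256, cols, "image_path", "pre", "tok", "schema") == ("tok", None, ["caption"])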
@@ -810,15 +822,18 @@ def main():
         if data_args.max_train_samples is not None:
             train_dataset = train_dataset.select(range(data_args.max_train_samples))
         train_dataset = train_dataset.map(
-            function=tokenization_fn,
+            function=function_kwarg,
             batched=True,
             num_proc=data_args.preprocessing_num_workers,
             # kept image paths
-            remove_columns=[x for x in column_names if x != image_column],
+            remove_columns=remove_columns_kwarg,
             load_from_cache_file=not data_args.overwrite_cache,
             desc=f"Running tokenizer on train dataset",
             fn_kwargs={"max_target_length": data_args.max_target_length},
+            features=features_kwarg,
         )
+        if not training_args.block_size:
+            train_dataset = train_dataset.with_format("numpy")
 
     if training_args.do_eval:
         if "validation" not in dataset:
@@ -829,15 +844,18 @@ def main():
         if data_args.max_eval_samples is not None:
             eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
         eval_dataset = eval_dataset.map(
-            function=tokenization_fn,
+            function=function_kwarg,
             batched=True,
             num_proc=data_args.preprocessing_num_workers,
             # kept image paths
-            remove_columns=[x for x in column_names if x != image_column],
+            remove_columns=remove_columns_kwarg,
             load_from_cache_file=not data_args.overwrite_cache,
             desc=f"Running tokenizer on validation dataset",
             fn_kwargs={"max_target_length": data_args.val_max_target_length},
+            features=features_kwarg,
         )
+        if not training_args.block_size:
+            eval_dataset = eval_dataset.with_format("numpy")
 
     if training_args.do_predict:
         if "test" not in dataset:
@@ -848,15 +866,18 @@ def main():
         if data_args.max_predict_samples is not None:
             predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
         predict_dataset = predict_dataset.map(
-            function=tokenization_fn,
+            function=function_kwarg,
             batched=True,
             num_proc=data_args.preprocessing_num_workers,
             # kept image paths
-            remove_columns=[x for x in column_names if x != image_column],
+            remove_columns=remove_columns_kwarg,
             load_from_cache_file=not data_args.overwrite_cache,
             desc=f"Running tokenizer on prediction dataset",
             fn_kwargs={"max_target_length": data_args.val_max_target_length},
+            features=features_kwarg,
        )
+        if not training_args.block_size:
+            predict_dataset = predict_dataset.with_format("numpy")
 
     # Split the dataset into several chunks - each chunk is processed (.map) without cache to create a
     # data loader separately (in a sequential order).
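The three map calls above only differ in the split they touch. The pattern they now share when block_size is 0 is passing features explicitly, so the Array3D pixel_values column is not re-inferred, and switching the dataset to numpy format so batches come back as arrays. A small, self-contained sketch of that datasets pattern, assuming the datasets library is installed and using a fake preprocessing function and toy data as stand-ins:

import numpy as np
import datasets

# Hypothetical toy dataset standing in for the image-captioning data.
raw = datasets.Dataset.from_dict({"image_path": ["a.jpg", "b.jpg"], "caption": ["a cat", "a dog"]})

features = datasets.Features(
    {
        "pixel_values": datasets.Array3D(shape=(3, 4, 4), dtype="float32"),
        "caption": datasets.Value("string"),
    }
)

def fake_preprocess(examples):
    # Stand-in for tokenization + feature extraction: emits zero "pixel values".
    n = len(examples["image_path"])
    return {"pixel_values": np.zeros((n, 3, 4, 4), dtype=np.float32), "caption": examples["caption"]}

ds = raw.map(
    fake_preprocess,
    batched=True,
    remove_columns=["image_path"],  # with block_size == 0 the image paths are no longer needed
    features=features,              # pin the schema instead of letting datasets re-infer it
)
ds = ds.with_format("numpy")
print(ds[0]["pixel_values"].shape)  # (3, 4, 4)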
@@ -894,46 +915,49 @@ def main():
         split: str = ""
     ):
 
-        if not block_size:
-            block_size = len(ds)
-
-        steps_per_block = block_size // batch_size
-        num_examples = len(ds)
-        steps = num_examples // batch_size
-        num_splits = steps // steps_per_block + int(steps % steps_per_block > 0)
-
         if shuffle:
             indices = jax.random.permutation(rng, len(ds))
             indices = np.asarray(indices)
         else:
             indices = np.arange(len(ds))
 
+        _block_size = len(ds) if not block_size else block_size
+
+        steps_per_block = _block_size // batch_size
+        num_examples = len(ds)
+        steps = num_examples // batch_size
+        num_splits = steps // steps_per_block + int(steps % steps_per_block > 0)
+
         for idx in range(num_splits):
 
-            start_idx = block_size * idx
-            end_idx = block_size * (idx + 1)
+            if not block_size:
+                _ds = ds
+            else:
 
-            selected_indices = indices[start_idx:end_idx]
+                start_idx = block_size * idx
+                end_idx = block_size * (idx + 1)
 
-            _ds = ds.select(selected_indices)
+                selected_indices = indices[start_idx:end_idx]
 
-            names = {
-                "train": "train",
-                "valid": "validation",
-                "test": "prediction",
-            }
+                _ds = ds.select(selected_indices)
 
-            _ds = _ds.map(
-                feature_extraction_fn,
-                batched=True,
-                num_proc=data_args.preprocessing_num_workers,
-                remove_columns=[image_column],
-                load_from_cache_file=not data_args.overwrite_cache,
-                features=features,
-                keep_in_memory=keep_in_memory,
-                desc=f"Running feature extraction on {names[split]} dataset".replace("  ", " "),
-            )
-            _ds = _ds.with_format("numpy")
+                names = {
+                    "train": "train",
+                    "valid": "validation",
+                    "test": "prediction",
+                }
+
+                _ds = _ds.map(
+                    feature_extraction_fn,
+                    batched=True,
+                    num_proc=data_args.preprocessing_num_workers,
+                    remove_columns=[image_column],
+                    load_from_cache_file=not data_args.overwrite_cache,
+                    features=features,
+                    keep_in_memory=keep_in_memory,
+                    desc=f"Running feature extraction on {names[split]} dataset".replace("  ", " "),
+                )
+                _ds = _ds.with_format("numpy")
 
             # No need to shuffle here
             loader = data_loader(rng, _ds, batch_size=batch_size, shuffle=False)
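The rewritten loader above replaces the old block_size = len(ds) fallback with a local _block_size, so a block_size of 0 means one block covering the whole dataset and the per-block select and feature-extraction branch is skipped entirely. The split arithmetic, isolated as a small sketch with made-up numbers:

# Illustrative sketch of the num_splits arithmetic used above.
def num_blocks(num_examples, batch_size, block_size):
    _block_size = num_examples if not block_size else block_size
    steps_per_block = _block_size // batch_size
    steps = num_examples // batch_size
    return steps // steps_per_block + int(steps % steps_per_block > 0)

assert num_blocks(num_examples=1000, batch_size=8, block_size=0) == 1    # whole dataset as one block
assert num_blocks(num_examples=1000, batch_size=8, block_size=256) == 4  # 125 steps in blocks of 32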
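Putting the pieces together, the control flow of the block-wise loader after this change could be sketched as below. This is a simplified, hypothetical helper rather than the script's loader: the dataset is a plain Python sequence, and extract_features stands in for the per-block map of feature_extraction_fn that only runs when block_size > 0 (with block_size == 0 the data is assumed to be fully preprocessed already).

def blockwise_batches(ds, batch_size, block_size, extract_features=None):
    # ds: any sequence of examples; yields fixed-size batches, block by block.
    indices = list(range(len(ds)))
    _block_size = len(ds) if not block_size else block_size
    steps_per_block = _block_size // batch_size
    steps = len(ds) // batch_size
    num_splits = steps // steps_per_block + int(steps % steps_per_block > 0)

    for idx in range(num_splits):
        if not block_size:
            block = list(ds)  # block_size == 0: single pass over the whole dataset
        else:
            selected = indices[block_size * idx : block_size * (idx + 1)]
            block = [ds[i] for i in selected]
            if extract_features is not None:
                block = [extract_features(example) for example in block]
        # plain sequential batching inside the block, no extra shuffling
        for start in range(0, (len(block) // batch_size) * batch_size, batch_size):
            yield block[start : start + batch_size]


data = list(range(20))
assert sum(1 for _ in blockwise_batches(data, batch_size=4, block_size=0)) == 5
assert sum(1 for _ in blockwise_batches(data, batch_size=4, block_size=8)) == 5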