ydshieh
committed on
Commit
·
8f31d11
1
Parent(s):
9f6265f
separate tokenization and feature extraction
Browse files- run_image_captioning_flax.py +85 -5
run_image_captioning_flax.py
CHANGED
@@ -680,6 +680,54 @@ def main():
|
|
680 |
|
681 |
return bools
|
682 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
683 |
# Setting padding="max_length" as we need fixed length inputs for jitted functions
|
684 |
def preprocess_function(examples, max_target_length):
|
685 |
|
@@ -741,6 +789,16 @@ def main():
|
|
741 |
train_dataset = train_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
|
742 |
if data_args.max_train_samples is not None:
|
743 |
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
744 |
|
745 |
if training_args.do_eval:
|
746 |
if "validation" not in dataset:
|
@@ -750,6 +808,16 @@ def main():
|
|
750 |
eval_dataset = eval_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
|
751 |
if data_args.max_eval_samples is not None:
|
752 |
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
753 |
|
754 |
if training_args.do_predict:
|
755 |
if "test" not in dataset:
|
@@ -759,6 +827,16 @@ def main():
|
|
759 |
predict_dataset = predict_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
|
760 |
if data_args.max_predict_samples is not None:
|
761 |
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
762 |
|
763 |
# Split the dataset into several chunks - each chunk is processed (.map) without cache to create a
|
764 |
# data loader separately (in a sequential order).
|
@@ -804,7 +882,6 @@ def main():
|
|
804 |
else:
|
805 |
indices = jnp.arange(len(ds))
|
806 |
|
807 |
-
# Temporarily set max_target_length for training or evaluation/prediction.
|
808 |
max_target_length = data_args.max_target_length
|
809 |
if split in ["valid", "test"]:
|
810 |
max_target_length = data_args.val_max_target_length
|
@@ -825,14 +902,17 @@ def main():
|
|
825 |
}
|
826 |
|
827 |
_ds =_ds.map(
|
828 |
-
preprocess_function,
|
|
|
829 |
batched=True,
|
830 |
num_proc=data_args.preprocessing_num_workers,
|
831 |
-
remove_columns=column_names,
|
|
|
832 |
load_from_cache_file=not data_args.overwrite_cache,
|
833 |
features=features,
|
834 |
-
desc=f"Running tokenizer on {names[split]} dataset".replace(" ", " "),
|
835 |
-
|
|
|
836 |
)
|
837 |
_ds = _ds.with_format("numpy")
|
838 |
|
|
|
680 |
|
681 |
return bools
|
682 |
|
683 |
+
def tokenization_fn(examples, max_target_length):
|
684 |
+
|
685 |
+
captions = []
|
686 |
+
for caption in examples[caption_column]:
|
687 |
+
captions.append(caption.lower() + ' ' + tokenizer.eos_token)
|
688 |
+
|
689 |
+
targets = captions
|
690 |
+
|
691 |
+
model_inputs = {}
|
692 |
+
|
693 |
+
# Setup the tokenizer for targets
|
694 |
+
with tokenizer.as_target_tokenizer():
|
695 |
+
labels = tokenizer(
|
696 |
+
targets, max_length=max_target_length, padding="max_length", truncation=True, return_tensors="np"
|
697 |
+
)
|
698 |
+
|
699 |
+
model_inputs["labels"] = labels["input_ids"]
|
700 |
+
decoder_input_ids = shift_tokens_right_fn(
|
701 |
+
labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
|
702 |
+
)
|
703 |
+
model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
|
704 |
+
|
705 |
+
# We need decoder_attention_mask so we can ignore pad tokens from loss
|
706 |
+
model_inputs["decoder_attention_mask"] = labels["attention_mask"]
|
707 |
+
|
708 |
+
model_inputs[image_column] = examples[image_column]
|
709 |
+
|
710 |
+
return model_inputs
|
711 |
+
|
712 |
+
def feature_extraction_fn(examples):
|
713 |
+
|
714 |
+
pixel_values = []
|
715 |
+
|
716 |
+
for image_file in examples[image_column]:
|
717 |
+
with Image.open(image_file) as image:
|
718 |
+
try:
|
719 |
+
encoder_inputs = feature_extractor(images=image, return_tensors="np")
|
720 |
+
except:
|
721 |
+
continue
|
722 |
+
pixel_values.append(encoder_inputs.pixel_values)
|
723 |
+
|
724 |
+
pixel_values = np.concatenate(pixel_values)
|
725 |
+
|
726 |
+
model_inputs = examples
|
727 |
+
model_inputs['pixel_values'] = pixel_values
|
728 |
+
|
729 |
+
return model_inputs
|
730 |
+
|
731 |
# Setting padding="max_length" as we need fixed length inputs for jitted functions
|
732 |
def preprocess_function(examples, max_target_length):
|
733 |
|
|
|
789 |
train_dataset = train_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
|
790 |
if data_args.max_train_samples is not None:
|
791 |
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
792 |
+
train_dataset = train_dataset.map(
|
793 |
+
tokenization_fn,
|
794 |
+
batched=True,
|
795 |
+
num_proc=data_args.preprocessing_num_workers,
|
796 |
+
# kept image paths
|
797 |
+
remove_columns=column_names.remove(image_column),
|
798 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
799 |
+
desc=f"Running tokenizer on train dataset",
|
800 |
+
fn_kwargs={"max_target_length": data_args.max_target_length},
|
801 |
+
)
|
802 |
|
803 |
if training_args.do_eval:
|
804 |
if "validation" not in dataset:
|
|
|
808 |
eval_dataset = eval_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
|
809 |
if data_args.max_eval_samples is not None:
|
810 |
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
811 |
+
eval_dataset = eval_dataset.map(
|
812 |
+
tokenization_fn,
|
813 |
+
batched=True,
|
814 |
+
num_proc=data_args.preprocessing_num_workers,
|
815 |
+
# kept image paths
|
816 |
+
remove_columns=column_names.remove(image_column),
|
817 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
818 |
+
desc=f"Running tokenizer on validation dataset",
|
819 |
+
fn_kwargs={"max_target_length": data_args.val_max_target_length},
|
820 |
+
)
|
821 |
|
822 |
if training_args.do_predict:
|
823 |
if "test" not in dataset:
|
|
|
827 |
predict_dataset = predict_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
|
828 |
if data_args.max_predict_samples is not None:
|
829 |
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
830 |
+
predict_dataset = predict_dataset.map(
|
831 |
+
tokenization_fn,
|
832 |
+
batched=True,
|
833 |
+
num_proc=data_args.preprocessing_num_workers,
|
834 |
+
# kept image paths
|
835 |
+
remove_columns=column_names.remove(image_column),
|
836 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
837 |
+
desc=f"Running tokenizer on prediction dataset",
|
838 |
+
fn_kwargs={"max_target_length": data_args.val_max_target_length},
|
839 |
+
)
|
840 |
|
841 |
# Split the dataset into several chunks - each chunk is processed (.map) without cache to create a
|
842 |
# data loader separately (in a sequential order).
|
|
|
882 |
else:
|
883 |
indices = jnp.arange(len(ds))
|
884 |
|
|
|
885 |
max_target_length = data_args.max_target_length
|
886 |
if split in ["valid", "test"]:
|
887 |
max_target_length = data_args.val_max_target_length
|
|
|
902 |
}
|
903 |
|
904 |
_ds =_ds.map(
|
905 |
+
# preprocess_function,
|
906 |
+
feature_extraction_fn,
|
907 |
batched=True,
|
908 |
num_proc=data_args.preprocessing_num_workers,
|
909 |
+
# remove_columns=column_names,
|
910 |
+
remove_columns=[image_column],
|
911 |
load_from_cache_file=not data_args.overwrite_cache,
|
912 |
features=features,
|
913 |
+
# desc=f"Running tokenizer on {names[split]} dataset".replace(" ", " "),
|
914 |
+
desc=f"Running feature extraction on {names[split]} dataset".replace(" ", " "),
|
915 |
+
# fn_kwargs={"max_target_length": max_target_length},
|
916 |
)
|
917 |
_ds = _ds.with_format("numpy")
|
918 |
|