ydshieh
committed on
Commit
·
e70a466
1
Parent(s):
4450aac
cleanup
Browse files - run_image_captioning_flax.py +1 -48
run_image_captioning_flax.py
CHANGED
@@ -680,6 +680,7 @@ def main():
|
|
680 |
|
681 |
return bools
|
682 |
|
|
|
683 |
def tokenization_fn(examples, max_target_length):
|
684 |
|
685 |
captions = []
|
@@ -728,43 +729,6 @@ def main():
|
|
728 |
|
729 |
return model_inputs
|
730 |
|
731 |
-
# Setting padding="max_length" as we need fixed length inputs for jitted functions
|
732 |
-
def preprocess_function(examples, max_target_length):
|
733 |
-
|
734 |
-
pixel_values = []
|
735 |
-
captions = []
|
736 |
-
for image_file, caption in zip(examples[image_column], examples[caption_column]):
|
737 |
-
with Image.open(image_file) as image:
|
738 |
-
try:
|
739 |
-
encoder_inputs = feature_extractor(images=image, return_tensors="np")
|
740 |
-
except:
|
741 |
-
continue
|
742 |
-
pixel_values.append(encoder_inputs.pixel_values)
|
743 |
-
captions.append(caption.lower() + ' ' + tokenizer.eos_token)
|
744 |
-
|
745 |
-
pixel_values = np.concatenate(pixel_values)
|
746 |
-
targets = captions
|
747 |
-
|
748 |
-
model_inputs = {}
|
749 |
-
model_inputs['pixel_values'] = pixel_values
|
750 |
-
|
751 |
-
# Setup the tokenizer for targets
|
752 |
-
with tokenizer.as_target_tokenizer():
|
753 |
-
labels = tokenizer(
|
754 |
-
targets, max_length=max_target_length, padding="max_length", truncation=True, return_tensors="np"
|
755 |
-
)
|
756 |
-
|
757 |
-
model_inputs["labels"] = labels["input_ids"]
|
758 |
-
decoder_input_ids = shift_tokens_right_fn(
|
759 |
-
labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
|
760 |
-
)
|
761 |
-
model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
|
762 |
-
|
763 |
-
# We need decoder_attention_mask so we can ignore pad tokens from loss
|
764 |
-
model_inputs["decoder_attention_mask"] = labels["attention_mask"]
|
765 |
-
|
766 |
-
return model_inputs
|
767 |
-
|
768 |
features = datasets.Features(
|
769 |
{
|
770 |
"pixel_values": datasets.Array3D(
|
@@ -874,18 +838,11 @@ def main():
|
|
874 |
steps = num_examples // batch_size + int(num_examples % batch_size > 0 and not drop_last_batch)
|
875 |
num_splits = steps // steps_per_split + int(steps % steps_per_split > 0)
|
876 |
|
877 |
-
if drop_last_batch:
|
878 |
-
num_examples = steps * batch_size
|
879 |
-
|
880 |
if shuffle:
|
881 |
indices = jax.random.permutation(input_rng, len(ds))
|
882 |
else:
|
883 |
indices = jnp.arange(len(ds))
|
884 |
|
885 |
-
max_target_length = data_args.max_target_length
|
886 |
-
if split in ["valid", "test"]:
|
887 |
-
max_target_length = data_args.val_max_target_length
|
888 |
-
|
889 |
for idx in range(num_splits):
|
890 |
|
891 |
start_idx = block_size * idx
|
@@ -902,17 +859,13 @@ def main():
|
|
902 |
}
|
903 |
|
904 |
_ds =_ds.map(
|
905 |
-
# preprocess_function,
|
906 |
feature_extraction_fn,
|
907 |
batched=True,
|
908 |
num_proc=data_args.preprocessing_num_workers,
|
909 |
-
# remove_columns=column_names,
|
910 |
remove_columns=[image_column],
|
911 |
load_from_cache_file=not data_args.overwrite_cache,
|
912 |
features=features,
|
913 |
-
# desc=f"Running tokenizer on {names[split]} dataset".replace(" ", " "),
|
914 |
desc=f"Running feature extraction on {names[split]} dataset".replace(" ", " "),
|
915 |
-
# fn_kwargs={"max_target_length": max_target_length},
|
916 |
)
|
917 |
_ds = _ds.with_format("numpy")
|
918 |
|
|
|
680 |
|
681 |
return bools
|
682 |
|
683 |
+
# Setting padding="max_length" as we need fixed length inputs for jitted functions
|
684 |
def tokenization_fn(examples, max_target_length):
|
685 |
|
686 |
captions = []
|
|
|
729 |
|
730 |
return model_inputs
|
731 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
732 |
features = datasets.Features(
|
733 |
{
|
734 |
"pixel_values": datasets.Array3D(
|
|
|
838 |
steps = num_examples // batch_size + int(num_examples % batch_size > 0 and not drop_last_batch)
|
839 |
num_splits = steps // steps_per_split + int(steps % steps_per_split > 0)
|
840 |
|
|
|
|
|
|
|
841 |
if shuffle:
|
842 |
indices = jax.random.permutation(input_rng, len(ds))
|
843 |
else:
|
844 |
indices = jnp.arange(len(ds))
|
845 |
|
|
|
|
|
|
|
|
|
846 |
for idx in range(num_splits):
|
847 |
|
848 |
start_idx = block_size * idx
|
|
|
859 |
}
|
860 |
|
861 |
_ds =_ds.map(
|
|
|
862 |
feature_extraction_fn,
|
863 |
batched=True,
|
864 |
num_proc=data_args.preprocessing_num_workers,
|
|
|
865 |
remove_columns=[image_column],
|
866 |
load_from_cache_file=not data_args.overwrite_cache,
|
867 |
features=features,
|
|
|
868 |
desc=f"Running feature extraction on {names[split]} dataset".replace(" ", " "),
|
|
|
869 |
)
|
870 |
_ds = _ds.with_format("numpy")
|
871 |
|