ydshieh committed
Commit 8f31d11 · Parent(s): 9f6265f

separate tokenization and feature extraction

Files changed (1):
  run_image_captioning_flax.py +85 -5
run_image_captioning_flax.py CHANGED
@@ -680,6 +680,54 @@ def main():
 
         return bools
 
+    def tokenization_fn(examples, max_target_length):
+
+        captions = []
+        for caption in examples[caption_column]:
+            captions.append(caption.lower() + " " + tokenizer.eos_token)
+
+        targets = captions
+
+        model_inputs = {}
+
+        # Set up the tokenizer for targets
+        with tokenizer.as_target_tokenizer():
+            labels = tokenizer(
+                targets, max_length=max_target_length, padding="max_length", truncation=True, return_tensors="np"
+            )
+
+        model_inputs["labels"] = labels["input_ids"]
+        decoder_input_ids = shift_tokens_right_fn(
+            labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
+        )
+        model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
+
+        # We need decoder_attention_mask so we can ignore pad tokens in the loss
+        model_inputs["decoder_attention_mask"] = labels["attention_mask"]
+
+        model_inputs[image_column] = examples[image_column]
+
+        return model_inputs
+
+    def feature_extraction_fn(examples):
+
+        pixel_values = []
+
+        for image_file in examples[image_column]:
+            with Image.open(image_file) as image:
+                try:
+                    encoder_inputs = feature_extractor(images=image, return_tensors="np")
+                except Exception:  # skip images that cannot be read or processed
+                    continue
+                pixel_values.append(encoder_inputs.pixel_values)
+
+        pixel_values = np.concatenate(pixel_values)
+
+        model_inputs = examples
+        model_inputs["pixel_values"] = pixel_values
+
+        return model_inputs
+
     # Setting padding="max_length" as we need fixed length inputs for jitted functions
     def preprocess_function(examples, max_target_length):
 
@@ -741,6 +789,16 @@ def main():
         train_dataset = train_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
         if data_args.max_train_samples is not None:
             train_dataset = train_dataset.select(range(data_args.max_train_samples))
+        train_dataset = train_dataset.map(
+            tokenization_fn,
+            batched=True,
+            num_proc=data_args.preprocessing_num_workers,
+            # keep the image-path column for the later feature-extraction step
+            remove_columns=[column for column in column_names if column != image_column],
+            load_from_cache_file=not data_args.overwrite_cache,
+            desc="Running tokenizer on train dataset",
+            fn_kwargs={"max_target_length": data_args.max_target_length},
+        )
 
     if training_args.do_eval:
         if "validation" not in dataset:
@@ -750,6 +808,16 @@ def main():
         eval_dataset = eval_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
         if data_args.max_eval_samples is not None:
             eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+        eval_dataset = eval_dataset.map(
+            tokenization_fn,
+            batched=True,
+            num_proc=data_args.preprocessing_num_workers,
+            # keep the image-path column for the later feature-extraction step
+            remove_columns=[column for column in column_names if column != image_column],
+            load_from_cache_file=not data_args.overwrite_cache,
+            desc="Running tokenizer on validation dataset",
+            fn_kwargs={"max_target_length": data_args.val_max_target_length},
+        )
 
     if training_args.do_predict:
         if "test" not in dataset:
@@ -759,6 +827,16 @@ def main():
         predict_dataset = predict_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
         if data_args.max_predict_samples is not None:
             predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
+        predict_dataset = predict_dataset.map(
+            tokenization_fn,
+            batched=True,
+            num_proc=data_args.preprocessing_num_workers,
+            # keep the image-path column for the later feature-extraction step
+            remove_columns=[column for column in column_names if column != image_column],
+            load_from_cache_file=not data_args.overwrite_cache,
+            desc="Running tokenizer on prediction dataset",
+            fn_kwargs={"max_target_length": data_args.val_max_target_length},
+        )
 
     # Split the dataset into several chunks - each chunk is processed (.map) without cache to create a
     # data loader separately (in a sequential order).
@@ -804,7 +882,6 @@ def main():
         else:
             indices = jnp.arange(len(ds))
 
-        # Temporarily set max_target_length for training or evaluation/prediction.
         max_target_length = data_args.max_target_length
         if split in ["valid", "test"]:
             max_target_length = data_args.val_max_target_length
@@ -825,14 +902,17 @@ def main():
         }
 
         _ds = _ds.map(
-            preprocess_function,
+            # preprocess_function,
+            feature_extraction_fn,
             batched=True,
            num_proc=data_args.preprocessing_num_workers,
-            remove_columns=column_names,
+            # remove_columns=column_names,
+            remove_columns=[image_column],
            load_from_cache_file=not data_args.overwrite_cache,
            features=features,
-            desc=f"Running tokenizer on {names[split]} dataset".replace("  ", " "),
-            fn_kwargs={"max_target_length": max_target_length},
+            # desc=f"Running tokenizer on {names[split]} dataset".replace("  ", " "),
+            desc=f"Running feature extraction on {names[split]} dataset".replace("  ", " "),
+            # fn_kwargs={"max_target_length": max_target_length},
         )
         _ds = _ds.with_format("numpy")
 
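
The point of the split: tokenization is cheap and its output (label ids) is small, so it runs once over the whole dataset and can be cached; feature extraction produces large pixel_values arrays, so it runs lazily on each chunk, with caching disabled, right before that chunk's data loader is built. Below is a minimal sketch of the same two-stage pattern, assuming a datasets.Dataset with "caption" and "image_path" columns and standard Hugging Face tokenizer / feature_extractor objects; the names are illustrative, not the script's exact variables.

import numpy as np
from PIL import Image

def tokenize_batch(batch, tokenizer, max_target_length):
    # Stage 1: cheap, run once over the full dataset and cached on disk.
    labels = tokenizer(
        [caption.lower() for caption in batch["caption"]],
        max_length=max_target_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )
    # Keep the image paths so stage 2 can open the files later.
    return {"labels": labels["input_ids"], "image_path": batch["image_path"]}

def extract_batch(batch, feature_extractor):
    # Stage 2: expensive and memory-heavy, run per chunk without caching.
    images = []
    for path in batch["image_path"]:
        with Image.open(path) as image:
            images.append(image.convert("RGB"))
    batch["pixel_values"] = feature_extractor(images=images, return_tensors="np").pixel_values
    return batch

# Stage 1 once up front; stage 2 on each block just before its data loader:
# ds = ds.map(tokenize_batch, batched=True,
#             fn_kwargs={"tokenizer": tokenizer, "max_target_length": 32})
# block = ds.select(range(0, 1000)).map(
#     extract_batch, batched=True, load_from_cache_file=False,
#     fn_kwargs={"feature_extractor": feature_extractor},
# )

One design note on the map calls above and in the diff: the columns to drop are built with a list comprehension because list.remove mutates the list in place and returns None, so passing remove_columns=column_names.remove(image_column) would silently hand the map call None and keep every column.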