ydshieh committed
Commit e70a466 · 1 Parent(s): 4450aac
Files changed (1)
  1. run_image_captioning_flax.py +1 -48
run_image_captioning_flax.py CHANGED
@@ -680,6 +680,7 @@ def main():
 
         return bools
 
+    # Setting padding="max_length" as we need fixed length inputs for jitted functions
     def tokenization_fn(examples, max_target_length):
 
         captions = []
@@ -728,43 +729,6 @@ def main():
 
         return model_inputs
 
-    # Setting padding="max_length" as we need fixed length inputs for jitted functions
-    def preprocess_function(examples, max_target_length):
-
-        pixel_values = []
-        captions = []
-        for image_file, caption in zip(examples[image_column], examples[caption_column]):
-            with Image.open(image_file) as image:
-                try:
-                    encoder_inputs = feature_extractor(images=image, return_tensors="np")
-                except:
-                    continue
-            pixel_values.append(encoder_inputs.pixel_values)
-            captions.append(caption.lower() + ' ' + tokenizer.eos_token)
-
-        pixel_values = np.concatenate(pixel_values)
-        targets = captions
-
-        model_inputs = {}
-        model_inputs['pixel_values'] = pixel_values
-
-        # Setup the tokenizer for targets
-        with tokenizer.as_target_tokenizer():
-            labels = tokenizer(
-                targets, max_length=max_target_length, padding="max_length", truncation=True, return_tensors="np"
-            )
-
-        model_inputs["labels"] = labels["input_ids"]
-        decoder_input_ids = shift_tokens_right_fn(
-            labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
-        )
-        model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
-
-        # We need decoder_attention_mask so we can ignore pad tokens from loss
-        model_inputs["decoder_attention_mask"] = labels["attention_mask"]
-
-        return model_inputs
-
     features = datasets.Features(
         {
             "pixel_values": datasets.Array3D(
@@ -874,18 +838,11 @@ def main():
         steps = num_examples // batch_size + int(num_examples % batch_size > 0 and not drop_last_batch)
         num_splits = steps // steps_per_split + int(steps % steps_per_split > 0)
 
-        if drop_last_batch:
-            num_examples = steps * batch_size
-
         if shuffle:
             indices = jax.random.permutation(input_rng, len(ds))
         else:
             indices = jnp.arange(len(ds))
 
-        max_target_length = data_args.max_target_length
-        if split in ["valid", "test"]:
-            max_target_length = data_args.val_max_target_length
-
        for idx in range(num_splits):
 
            start_idx = block_size * idx
@@ -902,17 +859,13 @@ def main():
            }
 
            _ds = _ds.map(
-                # preprocess_function,
                feature_extraction_fn,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
-                # remove_columns=column_names,
                remove_columns=[image_column],
                load_from_cache_file=not data_args.overwrite_cache,
                features=features,
-                # desc=f"Running tokenizer on {names[split]} dataset".replace("  ", " "),
                desc=f"Running feature extraction on {names[split]} dataset".replace("  ", " "),
-                # fn_kwargs={"max_target_length": max_target_length},
            )
            _ds = _ds.with_format("numpy")
 
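The comment introduced in this commit ("fixed length inputs for jitted functions") is worth unpacking: jax.jit specializes the compiled executable on the shapes of its arguments, so variable-length batches force a retrace and recompile every time a new length shows up. A minimal standalone sketch of that behavior (illustration only, not part of this script):

import jax
import jax.numpy as jnp

@jax.jit
def token_sum(ids):
    # Compiled once per distinct input shape/dtype.
    return ids.sum()

token_sum(jnp.ones((8,), dtype=jnp.int32))    # traces and compiles for shape (8,)
token_sum(jnp.ones((12,), dtype=jnp.int32))   # new shape: traces and compiles again
token_sum(jnp.zeros((12,), dtype=jnp.int32))  # same shape: reuses the cached executable

Padding every target to the same max_target_length therefore keeps the training step at a single compiled version instead of one per caption length.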
 
 
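For context on the removed preprocess_function: it delegates decoder-input construction to shift_tokens_right_fn. The standard right-shift helper in Flax seq2seq models follows this pattern, sketched here from memory rather than copied from this script: prepend decoder_start_token_id, drop the last label token, and map any -100 ignore positions back to the pad token.

import numpy as np

def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
    # Decoder inputs are the labels shifted one position to the right.
    shifted = np.zeros_like(input_ids)
    shifted[:, 1:] = input_ids[:, :-1]
    shifted[:, 0] = decoder_start_token_id
    # Positions marked -100 in the labels are ignored by the loss; feed pad tokens instead.
    return np.where(shifted == -100, pad_token_id, shifted)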
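The step arithmetic retained in the data-loader hunk is easiest to check with concrete numbers (hypothetical values, for illustration only):

num_examples, batch_size, steps_per_split = 1000, 64, 4
drop_last_batch = False

steps = num_examples // batch_size + int(num_examples % batch_size > 0 and not drop_last_batch)
# 1000 // 64 = 15 full batches; the 40 leftover examples keep one extra step -> steps == 16
num_splits = steps // steps_per_split + int(steps % steps_per_split > 0)
# 16 // 4 = 4 with no remainder -> num_splits == 4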