ydshieh committed
Commit eabb817 · 1 Parent(s): a897ce1

make block_size 0 work

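With `block_size = 0`, the script now tokenizes captions and extracts image features in a single cached `.map` pass (the new `preprocess_fn`), and the data loader simply iterates over the already-processed dataset; with a non-zero `block_size`, only tokenization happens up front and feature extraction still runs block by block inside the loader, as before. Below is a minimal sketch of how the `remove_columns` expression added in the diff behaves in both modes (column names here are illustrative, not taken from the script):

# Illustrative sketch of the remove_columns comprehension introduced in this commit.
column_names = ["image_path", "caption"]
image_column = "image_path"

def columns_to_remove(block_size: int) -> list:
    # Same comprehension as in the diff: when block_size is 0 the `or` clause is always
    # true, so every column (including the image path) is dropped, because pixel values
    # are already produced by preprocess_fn. With block_size > 0 the image path is kept
    # for the per-block feature extraction inside the data loader.
    return [x for x in column_names if x != image_column or not block_size]

print(columns_to_remove(block_size=0))    # ['image_path', 'caption']
print(columns_to_remove(block_size=256))  # ['caption']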
Files changed (1)
run_image_captioning_flax.py  +58 -34
run_image_captioning_flax.py CHANGED
@@ -785,6 +785,14 @@ def main():
 
         return model_inputs
 
+    def preprocess_fn(examples, max_target_length):
+
+        model_inputs = {}
+        model_inputs.update(tokenization_fn(examples, max_target_length))
+        model_inputs.update(feature_extraction_fn(model_inputs))
+
+        return model_inputs
+
     features = datasets.Features(
         {
             "pixel_values": datasets.Array3D(
@@ -801,6 +809,10 @@ def main():
         }
     )
 
+    function_kwarg = preprocess_fn if not training_args.block_size else tokenization_fn
+    features_kwarg = features if not training_args.block_size else None
+    remove_columns_kwarg = [x for x in column_names if x != image_column or not training_args.block_size]
+
     if training_args.do_train:
         if "train" not in dataset:
             raise ValueError("--do_train requires a train dataset")
@@ -810,15 +822,18 @@ def main():
         if data_args.max_train_samples is not None:
             train_dataset = train_dataset.select(range(data_args.max_train_samples))
         train_dataset = train_dataset.map(
-            tokenization_fn,
+            function=function_kwarg,
             batched=True,
             num_proc=data_args.preprocessing_num_workers,
             # kept image paths
-            remove_columns=[x for x in column_names if x != image_column],
+            remove_columns=remove_columns_kwarg,
             load_from_cache_file=not data_args.overwrite_cache,
             desc=f"Running tokenizer on train dataset",
             fn_kwargs={"max_target_length": data_args.max_target_length},
+            features=features_kwarg,
         )
+        if not training_args.block_size:
+            train_dataset = train_dataset.with_format("numpy")
 
     if training_args.do_eval:
         if "validation" not in dataset:
@@ -829,15 +844,18 @@ def main():
         if data_args.max_eval_samples is not None:
             eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
         eval_dataset = eval_dataset.map(
-            tokenization_fn,
+            function=function_kwarg,
             batched=True,
             num_proc=data_args.preprocessing_num_workers,
             # kept image paths
-            remove_columns=[x for x in column_names if x != image_column],
+            remove_columns=remove_columns_kwarg,
             load_from_cache_file=not data_args.overwrite_cache,
             desc=f"Running tokenizer on validation dataset",
             fn_kwargs={"max_target_length": data_args.val_max_target_length},
+            features=features_kwarg,
         )
+        if not training_args.block_size:
+            eval_dataset = eval_dataset.with_format("numpy")
 
     if training_args.do_predict:
         if "test" not in dataset:
@@ -848,15 +866,18 @@ def main():
         if data_args.max_predict_samples is not None:
             predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
         predict_dataset = predict_dataset.map(
-            tokenization_fn,
+            function=function_kwarg,
             batched=True,
             num_proc=data_args.preprocessing_num_workers,
             # kept image paths
-            remove_columns=[x for x in column_names if x != image_column],
+            remove_columns=remove_columns_kwarg,
             load_from_cache_file=not data_args.overwrite_cache,
             desc=f"Running tokenizer on prediction dataset",
             fn_kwargs={"max_target_length": data_args.val_max_target_length},
+            features=features_kwarg,
         )
+        if not training_args.block_size:
+            predict_dataset = predict_dataset.with_format("numpy")
 
     # Split the dataset into several chunks - each chunk is processed (.map) without cache to create a
     # data loader separately (in a sequential order).
@@ -894,46 +915,49 @@ def main():
         split: str = ""
     ):
 
-        if not block_size:
-            block_size = len(ds)
-
-        steps_per_block = block_size // batch_size
-        num_examples = len(ds)
-        steps = num_examples // batch_size
-        num_splits = steps // steps_per_block + int(steps % steps_per_block > 0)
-
         if shuffle:
             indices = jax.random.permutation(rng, len(ds))
             indices = np.asarray(indices)
         else:
             indices = np.arange(len(ds))
 
+        _block_size = len(ds) if not block_size else block_size
+
+        steps_per_block = _block_size // batch_size
+        num_examples = len(ds)
+        steps = num_examples // batch_size
+        num_splits = steps // steps_per_block + int(steps % steps_per_block > 0)
+
         for idx in range(num_splits):
 
-            start_idx = block_size * idx
-            end_idx = block_size * (idx + 1)
+            if not block_size:
+                _ds = ds
+            else:
 
-            selected_indices = indices[start_idx:end_idx]
+                start_idx = block_size * idx
+                end_idx = block_size * (idx + 1)
 
-            _ds = ds.select(selected_indices)
+                selected_indices = indices[start_idx:end_idx]
 
-            names = {
-                "train": "train",
-                "valid": "validation",
-                "test": "prediction",
-            }
+                _ds = ds.select(selected_indices)
 
-            _ds = _ds.map(
-                feature_extraction_fn,
-                batched=True,
-                num_proc=data_args.preprocessing_num_workers,
-                remove_columns=[image_column],
-                load_from_cache_file=not data_args.overwrite_cache,
-                features=features,
-                keep_in_memory=keep_in_memory,
-                desc=f"Running feature extraction on {names[split]} dataset".replace("  ", " "),
-            )
-            _ds = _ds.with_format("numpy")
+                names = {
+                    "train": "train",
+                    "valid": "validation",
+                    "test": "prediction",
+                }
+
+                _ds = _ds.map(
+                    feature_extraction_fn,
+                    batched=True,
+                    num_proc=data_args.preprocessing_num_workers,
+                    remove_columns=[image_column],
+                    load_from_cache_file=not data_args.overwrite_cache,
+                    features=features,
+                    keep_in_memory=keep_in_memory,
+                    desc=f"Running feature extraction on {names[split]} dataset".replace("  ", " "),
+                )
+                _ds = _ds.with_format("numpy")
 
         # No need to shuffle here
         loader = data_loader(rng, _ds, batch_size=batch_size, shuffle=False)
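For reference, a rough standalone sketch of the split arithmetic in the block-wise data loader after this change (the function name and the numbers are illustrative; the real loader also shuffles indices, selects each slice from a datasets.Dataset, and runs feature extraction per block):

# Illustrative sketch: block_size = 0 falls back to the dataset length, so the loop
# body runs exactly once and uses the dataset as-is instead of selecting and
# re-processing a slice for every block.
def plan_blocks(num_examples: int, batch_size: int, block_size: int):
    _block_size = num_examples if not block_size else block_size
    steps_per_block = _block_size // batch_size
    steps = num_examples // batch_size
    num_splits = steps // steps_per_block + int(steps % steps_per_block > 0)
    for idx in range(num_splits):
        if not block_size:
            yield (0, num_examples)  # single block: the whole, already preprocessed dataset
        else:
            yield (block_size * idx, block_size * (idx + 1))  # half-open index range

print(list(plan_blocks(num_examples=1000, batch_size=8, block_size=256)))
# [(0, 256), (256, 512), (512, 768), (768, 1024)] -> the last range is clipped when slicing indices
print(list(plan_blocks(num_examples=1000, batch_size=8, block_size=0)))
# [(0, 1000)] -> one pass, no per-epoch feature-extraction map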