ydshieh commited on
Commit
a897ce1
·
1 Parent(s): 89bf8e9

imporve feature_extraction_fn

Browse files
Files changed (1) hide show
  1. run_image_captioning_flax.py +3 -11
run_image_captioning_flax.py CHANGED
@@ -779,17 +779,9 @@ def main():
779
 
780
  def feature_extraction_fn(examples):
781
 
782
- pixel_values = []
783
-
784
- for image_file in examples[image_column]:
785
- with Image.open(image_file) as image:
786
- encoder_inputs = feature_extractor(images=image, return_tensors="np")
787
- pixel_values.append(encoder_inputs.pixel_values)
788
-
789
- pixel_values = np.concatenate(pixel_values)
790
-
791
- model_inputs = examples
792
- model_inputs['pixel_values'] = pixel_values
793
 
794
  return model_inputs
795
 
 
779
 
780
  def feature_extraction_fn(examples):
781
 
782
+ images = [Image.open(image_file) for image_file in examples[image_column]]
783
+ encoder_inputs = feature_extractor(images=images, return_tensors="np")
784
+ model_inputs = {"pixel_values": encoder_inputs.pixel_values}
 
 
 
 
 
 
 
 
785
 
786
  return model_inputs
787