ydshieh commited on
Commit
18ea816
·
1 Parent(s): a4b085f
Files changed (1) hide show
  1. run_image_captioning_flax.py +4 -1
run_image_captioning_flax.py CHANGED
@@ -794,12 +794,15 @@ def main():
794
 
795
  if check:
796
  for k, v in examples.items():
797
- examples[k] = v[bools]
 
798
  else:
799
  assert len(images) == len(examples)
800
 
801
  encoder_inputs = feature_extractor(images=images, return_tensors="np")
802
  model_inputs = {"pixel_values": encoder_inputs.pixel_values}
 
 
803
 
804
  return model_inputs
805
 
 
794
 
795
  if check:
796
  for k, v in examples.items():
797
+ if k != image_column:
798
+ examples[k] = v[bools]
799
  else:
800
  assert len(images) == len(examples)
801
 
802
  encoder_inputs = feature_extractor(images=images, return_tensors="np")
803
  model_inputs = {"pixel_values": encoder_inputs.pixel_values}
804
+ model_inputs.update(examples)
805
+ model_inputs.pop(image_column)
806
 
807
  return model_inputs
808