ydshieh commited on
Commit
efe62b9
·
1 Parent(s): 0c9b4f3
Files changed (1) hide show
  1. run_image_captioning_flax.py +3 -3
run_image_captioning_flax.py CHANGED
@@ -299,7 +299,7 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
299
  if shuffle:
300
  batch_idx = jax.random.permutation(rng, len(dataset))
301
  else:
302
- batch_idx = jnp.arange(len(dataset))
303
 
304
  batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
305
  batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
@@ -845,7 +845,7 @@ def main():
845
  num_splits = steps // steps_per_split + int(steps % steps_per_split > 0)
846
 
847
  if shuffle:
848
- indices = jax.random.permutation(input_rng, len(ds))
849
  else:
850
  indices = jnp.arange(len(ds))
851
 
@@ -864,7 +864,7 @@ def main():
864
  "test": "prediction",
865
  }
866
 
867
- _ds =_ds.map(
868
  feature_extraction_fn,
869
  batched=True,
870
  num_proc=data_args.preprocessing_num_workers,
 
299
  if shuffle:
300
  batch_idx = jax.random.permutation(rng, len(dataset))
301
  else:
302
+ batch_idx = np.arange(len(dataset))
303
 
304
  batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
305
  batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
 
845
  num_splits = steps // steps_per_split + int(steps % steps_per_split > 0)
846
 
847
  if shuffle:
848
+ indices = np.random.permutation(len(train_dataset))
849
  else:
850
  indices = jnp.arange(len(ds))
851
 
 
864
  "test": "prediction",
865
  }
866
 
867
+ _ds = _ds.map(
868
  feature_extraction_fn,
869
  batched=True,
870
  num_proc=data_args.preprocessing_num_workers,