ydshieh
commited on
Commit
·
efe62b9
1
Parent(s):
0c9b4f3
fix speed
Browse files
run_image_captioning_flax.py
CHANGED
@@ -299,7 +299,7 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
|
|
299 |
if shuffle:
|
300 |
batch_idx = jax.random.permutation(rng, len(dataset))
|
301 |
else:
|
302 |
-
batch_idx =
|
303 |
|
304 |
batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
|
305 |
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
|
@@ -845,7 +845,7 @@ def main():
|
|
845 |
num_splits = steps // steps_per_split + int(steps % steps_per_split > 0)
|
846 |
|
847 |
if shuffle:
|
848 |
-
indices =
|
849 |
else:
|
850 |
indices = jnp.arange(len(ds))
|
851 |
|
@@ -864,7 +864,7 @@ def main():
|
|
864 |
"test": "prediction",
|
865 |
}
|
866 |
|
867 |
-
_ds =_ds.map(
|
868 |
feature_extraction_fn,
|
869 |
batched=True,
|
870 |
num_proc=data_args.preprocessing_num_workers,
|
|
|
299 |
if shuffle:
|
300 |
batch_idx = jax.random.permutation(rng, len(dataset))
|
301 |
else:
|
302 |
+
batch_idx = np.arange(len(dataset))
|
303 |
|
304 |
batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
|
305 |
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
|
|
|
845 |
num_splits = steps // steps_per_split + int(steps % steps_per_split > 0)
|
846 |
|
847 |
if shuffle:
|
848 |
+
indices = np.random.permutation(len(train_dataset))
|
849 |
else:
|
850 |
indices = jnp.arange(len(ds))
|
851 |
|
|
|
864 |
"test": "prediction",
|
865 |
}
|
866 |
|
867 |
+
_ds = _ds.map(
|
868 |
feature_extraction_fn,
|
869 |
batched=True,
|
870 |
num_proc=data_args.preprocessing_num_workers,
|