ydshieh
committed on
Commit
·
6b61700
1
Parent(s):
0487060
update
Browse files — run_image_captioning_flax.py (+11 −11)
run_image_captioning_flax.py
CHANGED
@@ -359,13 +359,13 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
|
|
359 |
else:
|
360 |
batch_idx = np.arange(len(dataset))
|
361 |
|
362 |
-
|
363 |
-
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
|
364 |
|
365 |
-
|
366 |
-
|
367 |
-
batch = {k: jnp.array(v) for k, v in batch.items()}
|
368 |
|
|
|
|
|
369 |
batch = shard(batch)
|
370 |
|
371 |
yield batch
|
@@ -886,20 +886,20 @@ def main():
|
|
886 |
|
887 |
if training_args.do_eval:
|
888 |
num_eval_examples = len(eval_dataset)
|
889 |
-
eval_steps = num_eval_examples // eval_batch_size
|
890 |
|
891 |
if training_args.do_predict:
|
892 |
num_test_examples = len(predict_dataset)
|
893 |
-
test_steps = num_test_examples // eval_batch_size
|
894 |
|
895 |
-
def get_batch_iter(rng, ds, block_size, batch_size, shuffle=False,
|
896 |
|
897 |
if not block_size:
|
898 |
block_size = len(ds)
|
899 |
|
900 |
steps_per_split = block_size // batch_size
|
901 |
num_examples = len(ds)
|
902 |
-
steps = num_examples // batch_size
|
903 |
num_splits = steps // steps_per_split + int(steps % steps_per_split > 0)
|
904 |
|
905 |
if shuffle:
|
@@ -1155,7 +1155,7 @@ def main():
|
|
1155 |
labels = []
|
1156 |
|
1157 |
batches = get_batch_iter(rng, dataset, block_size=block_size, batch_size=eval_batch_size, keep_in_memory=False, shuffle=False, split=split)
|
1158 |
-
steps = len(dataset) // eval_batch_size
|
1159 |
for _ in tqdm(range(steps), desc=f"{'Predicting' if split == 'test' else 'Evaluating'}...", position=2, leave=False):
|
1160 |
# Model forward
|
1161 |
batch = next(batches)
|
@@ -1262,7 +1262,7 @@ def main():
|
|
1262 |
|
1263 |
train_metrics = []
|
1264 |
|
1265 |
-
train_batches = get_batch_iter(input_rng, train_dataset, block_size=block_size, batch_size=train_batch_size, keep_in_memory=True, shuffle=True,
|
1266 |
|
1267 |
# train
|
1268 |
for (batch_idx, _) in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)):
|
|
|
359 |
else:
|
360 |
batch_idx = np.arange(len(dataset))
|
361 |
|
362 |
+
for idx in range(steps_per_epoch):
|
|
|
363 |
|
364 |
+
start_idx = batch_size * idx
|
365 |
+
end_idx = batch_size * (idx + 1)
|
|
|
366 |
|
367 |
+
selected_indices = batch_idx[start_idx:end_idx]
|
368 |
+
batch = dataset[selected_indices]
|
369 |
batch = shard(batch)
|
370 |
|
371 |
yield batch
|
|
|
886 |
|
887 |
if training_args.do_eval:
|
888 |
num_eval_examples = len(eval_dataset)
|
889 |
+
eval_steps = num_eval_examples // eval_batch_size
|
890 |
|
891 |
if training_args.do_predict:
|
892 |
num_test_examples = len(predict_dataset)
|
893 |
+
test_steps = num_test_examples // eval_batch_size
|
894 |
|
895 |
+
def get_batch_iter(rng, ds, block_size, batch_size, shuffle=False, keep_in_memory=False, split=""):
|
896 |
|
897 |
if not block_size:
|
898 |
block_size = len(ds)
|
899 |
|
900 |
steps_per_split = block_size // batch_size
|
901 |
num_examples = len(ds)
|
902 |
+
steps = num_examples // batch_size
|
903 |
num_splits = steps // steps_per_split + int(steps % steps_per_split > 0)
|
904 |
|
905 |
if shuffle:
|
|
|
1155 |
labels = []
|
1156 |
|
1157 |
batches = get_batch_iter(rng, dataset, block_size=block_size, batch_size=eval_batch_size, keep_in_memory=False, shuffle=False, split=split)
|
1158 |
+
steps = len(dataset) // eval_batch_size
|
1159 |
for _ in tqdm(range(steps), desc=f"{'Predicting' if split == 'test' else 'Evaluating'}...", position=2, leave=False):
|
1160 |
# Model forward
|
1161 |
batch = next(batches)
|
|
|
1262 |
|
1263 |
train_metrics = []
|
1264 |
|
1265 |
+
train_batches = get_batch_iter(input_rng, train_dataset, block_size=block_size, batch_size=train_batch_size, keep_in_memory=True, shuffle=True, split="train")
|
1266 |
|
1267 |
# train
|
1268 |
for (batch_idx, _) in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)):
|