ydshieh committed on
Commit
6b61700
·
1 Parent(s): 0487060
Files changed (1) hide show
  1. run_image_captioning_flax.py +11 -11
run_image_captioning_flax.py CHANGED
@@ -359,13 +359,13 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
359
  else:
360
  batch_idx = np.arange(len(dataset))
361
 
362
- batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
363
- batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
364
 
365
- for idx in batch_idx:
366
- batch = dataset[idx]
367
- batch = {k: jnp.array(v) for k, v in batch.items()}
368
 
 
 
369
  batch = shard(batch)
370
 
371
  yield batch
@@ -886,20 +886,20 @@ def main():
886
 
887
  if training_args.do_eval:
888
  num_eval_examples = len(eval_dataset)
889
- eval_steps = num_eval_examples // eval_batch_size + int(num_eval_examples % eval_batch_size > 0)
890
 
891
  if training_args.do_predict:
892
  num_test_examples = len(predict_dataset)
893
- test_steps = num_test_examples // eval_batch_size + int(num_test_examples % eval_batch_size > 0)
894
 
895
- def get_batch_iter(rng, ds, block_size, batch_size, shuffle=False, drop_last_batch=False, keep_in_memory=False, split=""):
896
 
897
  if not block_size:
898
  block_size = len(ds)
899
 
900
  steps_per_split = block_size // batch_size
901
  num_examples = len(ds)
902
- steps = num_examples // batch_size + int(num_examples % batch_size > 0 and not drop_last_batch)
903
  num_splits = steps // steps_per_split + int(steps % steps_per_split > 0)
904
 
905
  if shuffle:
@@ -1155,7 +1155,7 @@ def main():
1155
  labels = []
1156
 
1157
  batches = get_batch_iter(rng, dataset, block_size=block_size, batch_size=eval_batch_size, keep_in_memory=False, shuffle=False, split=split)
1158
- steps = len(dataset) // eval_batch_size + int(len(dataset) % eval_batch_size > 0)
1159
  for _ in tqdm(range(steps), desc=f"{'Predicting' if split == 'test' else 'Evaluating'}...", position=2, leave=False):
1160
  # Model forward
1161
  batch = next(batches)
@@ -1262,7 +1262,7 @@ def main():
1262
 
1263
  train_metrics = []
1264
 
1265
- train_batches = get_batch_iter(input_rng, train_dataset, block_size=block_size, batch_size=train_batch_size, keep_in_memory=True, shuffle=True, drop_last_batch=training_args.dataloader_drop_last, split="train")
1266
 
1267
  # train
1268
  for (batch_idx, _) in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)):
 
359
  else:
360
  batch_idx = np.arange(len(dataset))
361
 
362
+ for idx in range(steps_per_epoch):
 
363
 
364
+ start_idx = batch_size * idx
365
+ end_idx = batch_size * (idx + 1)
 
366
 
367
+ selected_indices = batch_idx[start_idx:end_idx]
368
+ batch = dataset[selected_indices]
369
  batch = shard(batch)
370
 
371
  yield batch
 
886
 
887
  if training_args.do_eval:
888
  num_eval_examples = len(eval_dataset)
889
+ eval_steps = num_eval_examples // eval_batch_size
890
 
891
  if training_args.do_predict:
892
  num_test_examples = len(predict_dataset)
893
+ test_steps = num_test_examples // eval_batch_size
894
 
895
+ def get_batch_iter(rng, ds, block_size, batch_size, shuffle=False, keep_in_memory=False, split=""):
896
 
897
  if not block_size:
898
  block_size = len(ds)
899
 
900
  steps_per_split = block_size // batch_size
901
  num_examples = len(ds)
902
+ steps = num_examples // batch_size
903
  num_splits = steps // steps_per_split + int(steps % steps_per_split > 0)
904
 
905
  if shuffle:
 
1155
  labels = []
1156
 
1157
  batches = get_batch_iter(rng, dataset, block_size=block_size, batch_size=eval_batch_size, keep_in_memory=False, shuffle=False, split=split)
1158
+ steps = len(dataset) // eval_batch_size
1159
  for _ in tqdm(range(steps), desc=f"{'Predicting' if split == 'test' else 'Evaluating'}...", position=2, leave=False):
1160
  # Model forward
1161
  batch = next(batches)
 
1262
 
1263
  train_metrics = []
1264
 
1265
+ train_batches = get_batch_iter(input_rng, train_dataset, block_size=block_size, batch_size=train_batch_size, keep_in_memory=True, shuffle=True, split="train")
1266
 
1267
  # train
1268
  for (batch_idx, _) in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)):