ydshieh
commited on
Commit
·
91d8939
1
Parent(s):
0b49c18
fix
Browse files
run_image_captioning_flax.py
CHANGED
|
@@ -929,6 +929,9 @@ def main():
|
|
| 929 |
|
| 930 |
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
| 931 |
|
|
|
|
|
|
|
|
|
|
| 932 |
if training_args.do_train:
|
| 933 |
steps_per_epoch = len(train_dataset) // train_batch_size
|
| 934 |
num_train_examples_per_epoch = steps_per_epoch * train_batch_size
|
|
|
|
| 929 |
|
| 930 |
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
| 931 |
|
| 932 |
+
if training_args.block_size % train_batch_size > 0:
|
| 933 |
+
raise ValueError(f"`training_args.block_size` needs to be a multiple of the global batch size. Got {training_args.block_size} and {train_batch_size} instead.")
|
| 934 |
+
|
| 935 |
if training_args.do_train:
|
| 936 |
steps_per_epoch = len(train_dataset) // train_batch_size
|
| 937 |
num_train_examples_per_epoch = steps_per_epoch * train_batch_size
|