ydshieh
commited on
Commit
·
b05f10c
1
Parent(s):
afddfdc
improve doc
Browse files- run_image_captioning_flax.py +11 -4
run_image_captioning_flax.py
CHANGED
|
@@ -365,7 +365,7 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
|
|
| 365 |
Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
|
| 366 |
Shuffle batches if `shuffle` is `True`.
|
| 367 |
"""
|
| 368 |
-
steps = len(dataset) // batch_size
|
| 369 |
|
| 370 |
if shuffle:
|
| 371 |
batch_idx = jax.random.permutation(rng, len(dataset))
|
|
@@ -924,7 +924,7 @@ def main():
|
|
| 924 |
num_test_examples = len(predict_dataset)
|
| 925 |
test_steps = num_test_examples // eval_batch_size
|
| 926 |
|
| 927 |
-
def
|
| 928 |
rng: jax.random.PRNGKey,
|
| 929 |
ds: Dataset,
|
| 930 |
block_size: int,
|
|
@@ -933,6 +933,13 @@ def main():
|
|
| 933 |
keep_in_memory: bool = False,
|
| 934 |
split: str = ""
|
| 935 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 936 |
|
| 937 |
if shuffle:
|
| 938 |
indices = jax.random.permutation(rng, len(ds))
|
|
@@ -1197,7 +1204,7 @@ def main():
|
|
| 1197 |
preds = []
|
| 1198 |
labels = []
|
| 1199 |
|
| 1200 |
-
batches =
|
| 1201 |
steps = len(dataset) // eval_batch_size
|
| 1202 |
for _ in tqdm(range(steps), desc=f"{'Predicting' if split == 'test' else 'Evaluating'}...", position=2, leave=False):
|
| 1203 |
# Model forward
|
|
@@ -1305,7 +1312,7 @@ def main():
|
|
| 1305 |
|
| 1306 |
train_metrics = []
|
| 1307 |
|
| 1308 |
-
train_batches =
|
| 1309 |
|
| 1310 |
# train
|
| 1311 |
for (batch_idx, _) in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)):
|
|
|
|
| 365 |
Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
|
| 366 |
Shuffle batches if `shuffle` is `True`.
|
| 367 |
"""
|
| 368 |
+
steps = len(dataset) // batch_size # Skip incomplete batch.
|
| 369 |
|
| 370 |
if shuffle:
|
| 371 |
batch_idx = jax.random.permutation(rng, len(dataset))
|
|
|
|
| 924 |
num_test_examples = len(predict_dataset)
|
| 925 |
test_steps = num_test_examples // eval_batch_size
|
| 926 |
|
| 927 |
+
def blockwise_data_loader(
|
| 928 |
rng: jax.random.PRNGKey,
|
| 929 |
ds: Dataset,
|
| 930 |
block_size: int,
|
|
|
|
| 933 |
keep_in_memory: bool = False,
|
| 934 |
split: str = ""
|
| 935 |
):
|
| 936 |
+
"""
|
| 937 |
+
Wrap the simple `data_loader` in a block-wise way if `block_size` > 0, else it's the same as `data_loader`.
|
| 938 |
+
|
| 939 |
+
If `block_size` > 0, it requires `ds` to have a column that gives image paths in order to perform image feature
|
| 940 |
+
extraction (with the column name being specified by `image_column`). The tokenization should be done before
|
| 941 |
+
training in this case.
|
| 942 |
+
"""
|
| 943 |
|
| 944 |
if shuffle:
|
| 945 |
indices = jax.random.permutation(rng, len(ds))
|
|
|
|
| 1204 |
preds = []
|
| 1205 |
labels = []
|
| 1206 |
|
| 1207 |
+
batches = blockwise_data_loader(rng, dataset, block_size=training_args.block_size, batch_size=eval_batch_size, keep_in_memory=False, shuffle=False, split=split)
|
| 1208 |
steps = len(dataset) // eval_batch_size
|
| 1209 |
for _ in tqdm(range(steps), desc=f"{'Predicting' if split == 'test' else 'Evaluating'}...", position=2, leave=False):
|
| 1210 |
# Model forward
|
|
|
|
| 1312 |
|
| 1313 |
train_metrics = []
|
| 1314 |
|
| 1315 |
+
train_batches = blockwise_data_loader(input_rng, train_dataset, block_size=training_args.block_size, batch_size=train_batch_size, keep_in_memory=True, shuffle=True, split="train")
|
| 1316 |
|
| 1317 |
# train
|
| 1318 |
for (batch_idx, _) in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)):
|