ydshieh
committed on
Commit
·
283180e
1
Parent(s):
cb2b435
more general
Browse files
run_image_captioning_flax.py
CHANGED
@@ -633,6 +633,8 @@ def main():
|
|
633 |
|
634 |
return tokenizer
|
635 |
|
|
|
|
|
636 |
# Preprocessing the datasets.
|
637 |
# We need to tokenize inputs and targets.
|
638 |
if training_args.do_train:
|
@@ -688,8 +690,6 @@ def main():
|
|
688 |
# Setting padding="max_length" as we need fixed length inputs for jitted functions
|
689 |
def tokenization_fn(examples, max_target_length):
|
690 |
|
691 |
-
tokenizer = get_tokenizer()
|
692 |
-
|
693 |
captions = []
|
694 |
for caption in examples[caption_column]:
|
695 |
captions.append(caption.lower() + ' ' + tokenizer.eos_token)
|
@@ -834,7 +834,7 @@ def main():
|
|
834 |
num_test_examples = len(predict_dataset)
|
835 |
test_steps = num_test_examples // eval_batch_size + int(num_test_examples % eval_batch_size > 0)
|
836 |
|
837 |
-
def get_batch_iter(rng, ds, block_size, batch_size, shuffle=False, drop_last_batch=False, split=""):
|
838 |
|
839 |
if not block_size:
|
840 |
block_size = len(ds)
|
@@ -871,7 +871,7 @@ def main():
|
|
871 |
remove_columns=[image_column],
|
872 |
load_from_cache_file=not data_args.overwrite_cache,
|
873 |
features=features,
|
874 |
-
keep_in_memory=
|
875 |
desc=f"Running feature extraction on {names[split]} dataset".replace("  ", " "),
|
876 |
)
|
877 |
_ds = _ds.with_format("numpy")
|
@@ -1095,7 +1095,7 @@ def main():
|
|
1095 |
preds = []
|
1096 |
labels = []
|
1097 |
|
1098 |
-
batches = get_batch_iter(rng, dataset, block_size=block_size, batch_size=eval_batch_size, shuffle=False, split=name)
|
1099 |
steps = len(dataset) // eval_batch_size + int(len(dataset) % eval_batch_size > 0)
|
1100 |
for _ in tqdm(range(steps), desc=f"{'Predicting' if name == 'test' else 'Evaluating'}...", position=2, leave=False):
|
1101 |
# Model forward
|
@@ -1197,7 +1197,7 @@ def main():
|
|
1197 |
|
1198 |
train_metrics = []
|
1199 |
|
1200 |
-
train_batches = get_batch_iter(input_rng, train_dataset, block_size=block_size, batch_size=train_batch_size, shuffle=True, drop_last_batch=training_args.dataloader_drop_last, split="train")
|
1201 |
|
1202 |
# train
|
1203 |
for (batch_idx, _) in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)):
|
|
|
633 |
|
634 |
return tokenizer
|
635 |
|
636 |
+
tokenizer = get_tokenizer()
|
637 |
+
|
638 |
# Preprocessing the datasets.
|
639 |
# We need to tokenize inputs and targets.
|
640 |
if training_args.do_train:
|
|
|
690 |
# Setting padding="max_length" as we need fixed length inputs for jitted functions
|
691 |
def tokenization_fn(examples, max_target_length):
|
692 |
|
|
|
|
|
693 |
captions = []
|
694 |
for caption in examples[caption_column]:
|
695 |
captions.append(caption.lower() + ' ' + tokenizer.eos_token)
|
|
|
834 |
num_test_examples = len(predict_dataset)
|
835 |
test_steps = num_test_examples // eval_batch_size + int(num_test_examples % eval_batch_size > 0)
|
836 |
|
837 |
+
def get_batch_iter(rng, ds, block_size, batch_size, shuffle=False, drop_last_batch=False, keep_in_memory=False, split=""):
|
838 |
|
839 |
if not block_size:
|
840 |
block_size = len(ds)
|
|
|
871 |
remove_columns=[image_column],
|
872 |
load_from_cache_file=not data_args.overwrite_cache,
|
873 |
features=features,
|
874 |
+
keep_in_memory=keep_in_memory,
|
875 |
desc=f"Running feature extraction on {names[split]} dataset".replace("  ", " "),
|
876 |
)
|
877 |
_ds = _ds.with_format("numpy")
|
|
|
1095 |
preds = []
|
1096 |
labels = []
|
1097 |
|
1098 |
+
batches = get_batch_iter(rng, dataset, block_size=block_size, batch_size=eval_batch_size, keep_in_memory=False, shuffle=False, split=name)
|
1099 |
steps = len(dataset) // eval_batch_size + int(len(dataset) % eval_batch_size > 0)
|
1100 |
for _ in tqdm(range(steps), desc=f"{'Predicting' if name == 'test' else 'Evaluating'}...", position=2, leave=False):
|
1101 |
# Model forward
|
|
|
1197 |
|
1198 |
train_metrics = []
|
1199 |
|
1200 |
+
train_batches = get_batch_iter(input_rng, train_dataset, block_size=block_size, batch_size=train_batch_size, keep_in_memory=True, shuffle=True, drop_last_batch=training_args.dataloader_drop_last, split="train")
|
1201 |
|
1202 |
# train
|
1203 |
for (batch_idx, _) in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)):
|