ydshieh committed
Commit · 89bf8e9
Parent(s): 6b61700
update
Browse files
- run_image_captioning_flax.py +53 -42

run_image_captioning_flax.py CHANGED
@@ -112,7 +112,15 @@ class TrainingArguments:
112       per_device_eval_batch_size: int = field(
113           default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
114       )
115 -
116       learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
117       weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
118       adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
@@ -351,7 +359,7 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
351       Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
352       Shuffle batches if `shuffle` is `True`.
353       """
354 -
355
356       if shuffle:
357           batch_idx = jax.random.permutation(rng, len(dataset))
@@ -359,7 +367,7 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
359       else:
360           batch_idx = np.arange(len(dataset))
361
362 -     for idx in range(
363
364           start_idx = batch_size * idx
365           end_idx = batch_size * (idx + 1)
@@ -661,37 +669,31 @@ def main():
661               "You can do it from another script, save it, and load it from here, using --feature_extractor_name."
662           )
663
664 -
665 -
666 -     tokenizer =
667 -
668           tokenizer = AutoTokenizer.from_pretrained(
669 -             model_args.
670           )
671 -
672 -
673 -         tokenizer = AutoTokenizer.from_pretrained(
674 -             model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
675 -         )
676 -     except ValueError as e:
677 -         logger.warning(e)
678 -
679 -     # Check decoder
680 -     if not tokenizer:
681 -         if model_args.decoder_model_name_or_path:
682 -             tokenizer = AutoTokenizer.from_pretrained(
683 -                 model_args.decoder_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
684 -             )
685 -         else:
686 -             raise ValueError(
687 -                 "You are instantiating a new tokenizer from scratch. This is not supported by this script."
688 -                 "You can do it from another script, save it, and load it from here, using --tokenizer_name."
689 -             )
690 -     tokenizer.pad_token = tokenizer.convert_ids_to_tokens(config.pad_token_id)
691 -
692 -     return tokenizer
693
694 -
695
696       # Preprocessing the datasets.
697       # We need to tokenize inputs and targets.
@@ -864,8 +866,6 @@ def main():
864           fn_kwargs={"max_target_length": data_args.val_max_target_length},
865       )
866
867 -     tokenizer = get_tokenizer()
868 -
869       # Split the dataset into several chunks - each chunk is processed (.map) without cache to create a
870       # data loader separately (in a sequential order).
871       block_size = training_args.block_size
@@ -892,18 +892,26 @@ def main():
892       num_test_examples = len(predict_dataset)
893       test_steps = num_test_examples // eval_batch_size
894
895 -     def get_batch_iter(
896
897           if not block_size:
898               block_size = len(ds)
899
900 -
901           num_examples = len(ds)
902           steps = num_examples // batch_size
903 -         num_splits = steps //
904
905           if shuffle:
906 -             indices = jax.random.permutation(rng, len(
907               indices = np.asarray(indices)
908           else:
909               indices = np.arange(len(ds))
@@ -1131,7 +1139,7 @@ def main():
1131      if not os.path.isdir(os.path.join(training_args.output_dir)):
1132          os.makedirs(os.path.join(training_args.output_dir), exist_ok=True)
1133
1134 -    def save_results(epoch, step):
1135
1136          # save checkpoint after each epoch and push checkpoint to the hub
1137          if jax.process_index() == 0:
@@ -1143,7 +1151,7 @@ def main():
1143          commit_msg = f"Saving weights and logs of epoch {epoch + 1}- step {step}"
1144          repo.push_to_hub(commit_message=commit_msg, blocking=False)
1145
1146 -    def evaluation_loop(rng, dataset, split):
1147
1148          if split not in ["valid", "test"]:
1149              raise ValueError(f"`name` must be either \"valid\" or \"test\". Got {split} instead.")
@@ -1239,10 +1247,10 @@ def main():
1239          with open(os.path.join(training_args.output_dir, ckpt_dir, f'generation_{split}.json'), 'w', encoding='UTF-8') as fp:
1240              json.dump(generations, fp, ensure_ascii=False, indent=4)
1241
1242 -    def evaluate(rng, dataset):
1243          evaluation_loop(rng, dataset, split='valid')
1244
1245 -    def predict(rng, dataset):
1246          evaluation_loop(rng, dataset, split='test')
1247
1248      input_rng = None
@@ -1292,7 +1300,8 @@ def main():
1292          if has_tensorboard and jax.process_index() == 0:
1293              write_metric(summary_writer, "train", train_metrics, cur_step, train_time=train_time)
1294
1295 -        # ======================== Evaluating ==============================
1296          if training_args.do_eval and (training_args.eval_steps is not None and training_args.eval_steps > 0) and cur_step % training_args.eval_steps == 0:
1297              evaluate(input_rng, eval_dataset)
1298              save_results(epoch, cur_step)
@@ -1311,6 +1320,8 @@ def main():
1311          if has_tensorboard and jax.process_index() == 0:
1312              write_metric(summary_writer, "train", train_metrics, cur_step, train_time=train_time)
1313
1314          if training_args.do_eval and (training_args.eval_steps is None or training_args.eval_steps <= 0):
1315              evaluate(input_rng, eval_dataset)
1316              save_results(epoch, cur_step)
@@ -112,7 +112,15 @@ class TrainingArguments:
112       per_device_eval_batch_size: int = field(
113           default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
114       )
115 +     _block_size_doc = \
116 +     """
117 +     Split a dataset into chunks of size `block_size`. On each block, images are transformed by the feature extractor
118 +     and are kept in memory, and the batches of size `batch_size` are yield before processing the next block.
119 +     """
120 +     block_size: int = field(
121 +         default=64,
122 +         metadata={"help": _block_size_doc}
123 +     )
124       learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
125       weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
126       adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
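The new `block_size` training argument bounds how many feature-extracted images are held in memory at once: the dataset is consumed block by block, and batches are drawn from the current block before the next one is prepared. A minimal, self-contained sketch of that iteration order (plain Python; `iterate_in_blocks` and `extract_features` are hypothetical stand-ins, and partial batches are kept here for simplicity):

    # Sketch: walk `examples` in blocks of `block_size`, yielding batches of `batch_size`
    # from each block before touching the next block.
    def iterate_in_blocks(examples, block_size, batch_size, extract_features=lambda xs: xs):
        for block_start in range(0, len(examples), block_size):
            block = extract_features(examples[block_start:block_start + block_size])
            for batch_start in range(0, len(block), batch_size):
                yield block[batch_start:batch_start + batch_size]

    # 10 examples, block_size=4, batch_size=2:
    list(iterate_in_blocks(list(range(10)), block_size=4, batch_size=2))
    # -> [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]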
@@ -351,7 +359,7 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
359       Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
360       Shuffle batches if `shuffle` is `True`.
361       """
362 +     steps = len(dataset) // batch_size
363
364       if shuffle:
365           batch_idx = jax.random.permutation(rng, len(dataset))
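The floor division drops any trailing examples that do not fill a complete batch, which is what the docstring means by a truncated dataset. For example, with illustrative numbers:

    # 1050 examples with batch_size 64 -> 16 full batches; the last 26 examples are skipped
    steps = 1050 // 64  # 16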
@@ -359,7 +367,7 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
367       else:
368           batch_idx = np.arange(len(dataset))
369
370 +     for idx in range(steps):
371
372           start_idx = batch_size * idx
373           end_idx = batch_size * (idx + 1)
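Together the two hunks make `data_loader` walk an (optionally shuffled) index array in fixed-size slices. A condensed NumPy-only sketch of the resulting loop, leaving out the per-batch device sharding the docstring mentions (the sketch and the `simple_data_loader` name are illustrative, not the script's code):

    import numpy as np

    def simple_data_loader(dataset, batch_size, shuffle=False, seed=0):
        # Mirrors the structure above: full batches only, taken from a permuted index array.
        steps = len(dataset) // batch_size
        rng = np.random.default_rng(seed)
        batch_idx = rng.permutation(len(dataset)) if shuffle else np.arange(len(dataset))
        for idx in range(steps):
            start_idx = batch_size * idx
            end_idx = batch_size * (idx + 1)
            yield [dataset[int(i)] for i in batch_idx[start_idx:end_idx]]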
@@ -661,37 +669,31 @@ def main():
669               "You can do it from another script, save it, and load it from here, using --feature_extractor_name."
670           )
671
672 +     tokenizer = None
673 +     if model_args.tokenizer_name:
674 +         tokenizer = AutoTokenizer.from_pretrained(
675 +             model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
676 +         )
677 +     elif model_args.model_name_or_path:
678 +         try:
679               tokenizer = AutoTokenizer.from_pretrained(
680 +                 model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
681               )
682 +         except ValueError as e:
683 +             logger.warning(e)
684
685 +     # Check decoder
686 +     if not tokenizer:
687 +         if model_args.decoder_model_name_or_path:
688 +             tokenizer = AutoTokenizer.from_pretrained(
689 +                 model_args.decoder_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
690 +             )
691 +         else:
692 +             raise ValueError(
693 +                 "You are instantiating a new tokenizer from scratch. This is not supported by this script."
694 +                 "You can do it from another script, save it, and load it from here, using --tokenizer_name."
695 +             )
696 +     tokenizer.pad_token = tokenizer.convert_ids_to_tokens(config.pad_token_id)
697
698       # Preprocessing the datasets.
699       # We need to tokenize inputs and targets.
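The tokenizer is now resolved through a fallback chain: an explicit `--tokenizer_name`, then the main checkpoint itself (a `ValueError` from `AutoTokenizer` is only logged as a warning), and finally the decoder checkpoint; if none of these yields a tokenizer, the script raises. A hypothetical helper that captures only this selection order, with no actual `AutoTokenizer.from_pretrained` call:

    # Hypothetical sketch of the fallback order; `model_has_tokenizer` stands in
    # for the try/except around loading a tokenizer from `model_name_or_path`.
    def pick_tokenizer_source(tokenizer_name=None, model_name_or_path=None,
                              decoder_model_name_or_path=None, model_has_tokenizer=True):
        if tokenizer_name:
            return tokenizer_name
        if model_name_or_path and model_has_tokenizer:
            return model_name_or_path
        if decoder_model_name_or_path:
            return decoder_model_name_or_path  # the "Check decoder" branch
        raise ValueError("Instantiating a new tokenizer from scratch is not supported by this script.")

    # Illustrative: only the decoder checkpoint provides a usable tokenizer.
    assert pick_tokenizer_source(model_name_or_path="my-enc-dec", model_has_tokenizer=False,
                                 decoder_model_name_or_path="gpt2") == "gpt2"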
@@ -864,8 +866,6 @@ def main():
866           fn_kwargs={"max_target_length": data_args.val_max_target_length},
867       )
868
869       # Split the dataset into several chunks - each chunk is processed (.map) without cache to create a
870       # data loader separately (in a sequential order).
871       block_size = training_args.block_size
@@ -892,18 +892,26 @@ def main():
892       num_test_examples = len(predict_dataset)
893       test_steps = num_test_examples // eval_batch_size
894
895 +     def get_batch_iter(
896 +         rng: jax.random.PRNGKey,
897 +         ds: Dataset,
898 +         block_size: int,
899 +         batch_size: int,
900 +         shuffle: bool = False,
901 +         keep_in_memory: bool = False,
902 +         split: str = ""
903 +     ):
904
905           if not block_size:
906               block_size = len(ds)
907
908 +         steps_per_block = block_size // batch_size
909           num_examples = len(ds)
910           steps = num_examples // batch_size
911 +         num_splits = steps // steps_per_block + int(steps % steps_per_block > 0)
912
913           if shuffle:
914 +             indices = jax.random.permutation(rng, len(ds))
915               indices = np.asarray(indices)
916           else:
917               indices = np.arange(len(ds))
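The block arithmetic above: `steps_per_block` is how many full batches fit in one block, and `num_splits` rounds up so a final partial block is still processed. A quick worked example using the default `block_size` and some illustrative numbers:

    block_size, batch_size, num_examples = 64, 16, 1000  # batch_size and num_examples are illustrative

    steps_per_block = block_size // batch_size  # 4 batches per block
    steps = num_examples // batch_size          # 62 full batches in total
    num_splits = steps // steps_per_block + int(steps % steps_per_block > 0)  # 15 full blocks + 1 partial = 16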
@@ -1131,7 +1139,7 @@ def main():
1139      if not os.path.isdir(os.path.join(training_args.output_dir)):
1140          os.makedirs(os.path.join(training_args.output_dir), exist_ok=True)
1141
1142 +    def save_results(epoch: int, step: int):
1143
1144          # save checkpoint after each epoch and push checkpoint to the hub
1145          if jax.process_index() == 0:
@@ -1143,7 +1151,7 @@ def main():
1151          commit_msg = f"Saving weights and logs of epoch {epoch + 1}- step {step}"
1152          repo.push_to_hub(commit_message=commit_msg, blocking=False)
1153
1154 +    def evaluation_loop(rng: jax.random.PRNGKey, dataset: Dataset, split: str):
1155
1156          if split not in ["valid", "test"]:
1157              raise ValueError(f"`name` must be either \"valid\" or \"test\". Got {split} instead.")
@@ -1239,10 +1247,10 @@ def main():
1247          with open(os.path.join(training_args.output_dir, ckpt_dir, f'generation_{split}.json'), 'w', encoding='UTF-8') as fp:
1248              json.dump(generations, fp, ensure_ascii=False, indent=4)
1249
1250 +    def evaluate(rng: jax.random.PRNGKey, dataset: Dataset):
1251          evaluation_loop(rng, dataset, split='valid')
1252
1253 +    def predict(rng: jax.random.PRNGKey, dataset: Dataset):
1254          evaluation_loop(rng, dataset, split='test')
1255
1256      input_rng = None
@@ -1292,7 +1300,8 @@ def main():
1300          if has_tensorboard and jax.process_index() == 0:
1301              write_metric(summary_writer, "train", train_metrics, cur_step, train_time=train_time)
1302
1303 +        # ======================== Evaluating (inside epoch) ==============================
1304 +
1305          if training_args.do_eval and (training_args.eval_steps is not None and training_args.eval_steps > 0) and cur_step % training_args.eval_steps == 0:
1306              evaluate(input_rng, eval_dataset)
1307              save_results(epoch, cur_step)
@@ -1311,6 +1320,8 @@ def main():
1320      if has_tensorboard and jax.process_index() == 0:
1321          write_metric(summary_writer, "train", train_metrics, cur_step, train_time=train_time)
1322
1323 +    # ======================== Evaluating (after each epoch) ==============================
1324 +
1325      if training_args.do_eval and (training_args.eval_steps is None or training_args.eval_steps <= 0):
1326          evaluate(input_rng, eval_dataset)
1327          save_results(epoch, cur_step)
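The two new banner comments mark the two evaluation paths the training loop now distinguishes: with `--do_eval` and `eval_steps > 0`, `evaluate` and `save_results` run every `eval_steps` steps inside the epoch; otherwise they run once at the end of each epoch. A small sketch of that decision, using a hypothetical `should_evaluate` helper rather than the script's inline conditions:

    # Hypothetical condensation of the two `do_eval` conditions above.
    def should_evaluate(do_eval, eval_steps, cur_step, end_of_epoch):
        if not do_eval:
            return False
        if eval_steps is not None and eval_steps > 0:
            return cur_step % eval_steps == 0  # step-based: inside the epoch
        return end_of_epoch                    # otherwise: once per epoch

    assert should_evaluate(True, 100, 300, end_of_epoch=False)
    assert should_evaluate(True, None, 300, end_of_epoch=True)
    assert not should_evaluate(True, None, 300, end_of_epoch=False)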