ydshieh committed
Commit · 89bf8e9 · 1 Parent(s): 6b61700

update

run_image_captioning_flax.py  +53 -42
run_image_captioning_flax.py
CHANGED
@@ -112,7 +112,15 @@ class TrainingArguments:
     per_device_eval_batch_size: int = field(
         default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
     )
-
+    _block_size_doc = \
+    """
+    Split a dataset into chunks of size `block_size`. On each block, images are transformed by the feature extractor
+    and are kept in memory, and the batches of size `batch_size` are yield before processing the next block.
+    """
+    block_size: int = field(
+        default=64,
+        metadata={"help": _block_size_doc}
+    )
     learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
     weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
     adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
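The new `block_size` help text describes a chunk-then-batch iteration pattern: images are preprocessed one block at a time, and batches are drawn from that block before the next block is touched. A minimal standalone sketch of that pattern (the helper name and the numbers are illustrative, not part of the commit):

import numpy as np

def iterate_in_blocks(num_examples, block_size, batch_size):
    # Walk the dataset one block at a time; in the script, the feature extractor runs
    # on a whole block and its outputs stay in memory while the block is batched.
    indices = np.arange(num_examples)
    for block_start in range(0, num_examples, block_size):
        block = indices[block_start: block_start + block_size]
        for batch_start in range(0, len(block), batch_size):
            yield block[batch_start: batch_start + batch_size]

# e.g. 10 examples, block_size=4, batch_size=2 yields [0 1], [2 3], [4 5], [6 7], [8 9]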
@@ -351,7 +359,7 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
     Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
     Shuffle batches if `shuffle` is `True`.
     """
-
+    steps = len(dataset) // batch_size

     if shuffle:
         batch_idx = jax.random.permutation(rng, len(dataset))
@@ -359,7 +367,7 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
     else:
         batch_idx = np.arange(len(dataset))

-    for idx in range(
+    for idx in range(steps):

         start_idx = batch_size * idx
         end_idx = batch_size * (idx + 1)
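In the hunk above, `steps = len(dataset) // batch_size` uses integer division, so a trailing partial batch is dropped. A small sketch of the same indexing logic (hypothetical helper, illustrative values):

import numpy as np

def batch_indices(num_examples, batch_size):
    steps = num_examples // batch_size  # integer division: the last partial batch is dropped
    idx = np.arange(num_examples)
    for step in range(steps):
        yield idx[step * batch_size:(step + 1) * batch_size]

# e.g. num_examples=10, batch_size=4 yields [0 1 2 3] and [4 5 6 7]; examples 8 and 9 are skipped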
@@ -661,37 +669,31 @@ def main():
             "You can do it from another script, save it, and load it from here, using --feature_extractor_name."
         )

-
-
-        tokenizer =
-
+    tokenizer = None
+    if model_args.tokenizer_name:
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+        )
+    elif model_args.model_name_or_path:
+        try:
             tokenizer = AutoTokenizer.from_pretrained(
-                model_args.
+                model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
             )
-
-
-                tokenizer = AutoTokenizer.from_pretrained(
-                    model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
-                )
-            except ValueError as e:
-                logger.warning(e)
-
-        # Check decoder
-        if not tokenizer:
-            if model_args.decoder_model_name_or_path:
-                tokenizer = AutoTokenizer.from_pretrained(
-                    model_args.decoder_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
-                )
-            else:
-                raise ValueError(
-                    "You are instantiating a new tokenizer from scratch. This is not supported by this script."
-                    "You can do it from another script, save it, and load it from here, using --tokenizer_name."
-                )
-        tokenizer.pad_token = tokenizer.convert_ids_to_tokens(config.pad_token_id)
-
-        return tokenizer
+        except ValueError as e:
+            logger.warning(e)

-
+    # Check decoder
+    if not tokenizer:
+        if model_args.decoder_model_name_or_path:
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_args.decoder_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+            )
+        else:
+            raise ValueError(
+                "You are instantiating a new tokenizer from scratch. This is not supported by this script."
+                "You can do it from another script, save it, and load it from here, using --tokenizer_name."
+            )
+    tokenizer.pad_token = tokenizer.convert_ids_to_tokens(config.pad_token_id)

     # Preprocessing the datasets.
     # We need to tokenize inputs and targets.
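The rewritten block above resolves the tokenizer in a fixed order: an explicit --tokenizer_name first, then the main checkpoint (a ValueError there is only logged), then the decoder checkpoint, and otherwise it raises. A condensed sketch of that order, assuming the same model_args and config objects (the helper name is hypothetical; the script itself does this inline):

import logging
from transformers import AutoTokenizer

logger = logging.getLogger(__name__)

def resolve_tokenizer(model_args, config):
    tokenizer = None
    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name)
    elif model_args.model_name_or_path:
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
        except ValueError as e:
            logger.warning(e)  # fall through to the decoder checkpoint below
    if tokenizer is None:
        if model_args.decoder_model_name_or_path:
            tokenizer = AutoTokenizer.from_pretrained(model_args.decoder_model_name_or_path)
        else:
            raise ValueError("No tokenizer could be resolved; pass --tokenizer_name.")
    # Keep the padding token consistent with the model config, as the script does.
    tokenizer.pad_token = tokenizer.convert_ids_to_tokens(config.pad_token_id)
    return tokenizer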
@@ -864,8 +866,6 @@ def main():
         fn_kwargs={"max_target_length": data_args.val_max_target_length},
     )

-    tokenizer = get_tokenizer()
-
     # Split the dataset into several chunks - each chunk is processed (.map) without cache to create a
     # data loader separately (in a sequential order).
     block_size = training_args.block_size
@@ -892,18 +892,26 @@ def main():
     num_test_examples = len(predict_dataset)
     test_steps = num_test_examples // eval_batch_size

-    def get_batch_iter(
+    def get_batch_iter(
+        rng: jax.random.PRNGKey,
+        ds: Dataset,
+        block_size: int,
+        batch_size: int,
+        shuffle: bool = False,
+        keep_in_memory: bool = False,
+        split: str = ""
+    ):

         if not block_size:
             block_size = len(ds)

-
+        steps_per_block = block_size // batch_size
         num_examples = len(ds)
         steps = num_examples // batch_size
-        num_splits = steps //
+        num_splits = steps // steps_per_block + int(steps % steps_per_block > 0)

         if shuffle:
-            indices = jax.random.permutation(rng, len(
+            indices = jax.random.permutation(rng, len(ds))
             indices = np.asarray(indices)
         else:
             indices = np.arange(len(ds))
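To make the new `num_splits` expression concrete, a worked example with illustrative numbers (not taken from the commit); it is simply a ceiling division of the total number of batches by the number of batches per block:

block_size, batch_size, num_examples = 64, 32, 1000       # illustrative values only
steps_per_block = block_size // batch_size                 # 2 batches fit in one block
steps = num_examples // batch_size                         # 31 full batches in total
num_splits = steps // steps_per_block + int(steps % steps_per_block > 0)
assert (steps_per_block, steps, num_splits) == (2, 31, 16)  # i.e. ceil(31 / 2) blocks are needed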
@@ -1131,7 +1139,7 @@ def main():
     if not os.path.isdir(os.path.join(training_args.output_dir)):
         os.makedirs(os.path.join(training_args.output_dir), exist_ok=True)

-    def save_results(epoch, step):
+    def save_results(epoch: int, step: int):

         # save checkpoint after each epoch and push checkpoint to the hub
         if jax.process_index() == 0:
@@ -1143,7 +1151,7 @@ def main():
             commit_msg = f"Saving weights and logs of epoch {epoch + 1}- step {step}"
             repo.push_to_hub(commit_message=commit_msg, blocking=False)

-    def evaluation_loop(rng, dataset, split):
+    def evaluation_loop(rng: jax.random.PRNGKey, dataset: Dataset, split: str):

         if split not in ["valid", "test"]:
             raise ValueError(f"`name` must be either \"valid\" or \"test\". Got {split} instead.")
@@ -1239,10 +1247,10 @@ def main():
             with open(os.path.join(training_args.output_dir, ckpt_dir, f'generation_{split}.json'), 'w', encoding='UTF-8') as fp:
                 json.dump(generations, fp, ensure_ascii=False, indent=4)

-    def evaluate(rng, dataset):
+    def evaluate(rng: jax.random.PRNGKey, dataset: Dataset):
         evaluation_loop(rng, dataset, split='valid')

-    def predict(rng, dataset):
+    def predict(rng: jax.random.PRNGKey, dataset: Dataset):
         evaluation_loop(rng, dataset, split='test')

     input_rng = None
@@ -1292,7 +1300,8 @@ def main():
             if has_tensorboard and jax.process_index() == 0:
                 write_metric(summary_writer, "train", train_metrics, cur_step, train_time=train_time)

-            # ======================== Evaluating ==============================
+            # ======================== Evaluating (inside epoch) ==============================
+
             if training_args.do_eval and (training_args.eval_steps is not None and training_args.eval_steps > 0) and cur_step % training_args.eval_steps == 0:
                 evaluate(input_rng, eval_dataset)
                 save_results(epoch, cur_step)
@@ -1311,6 +1320,8 @@ def main():
         if has_tensorboard and jax.process_index() == 0:
             write_metric(summary_writer, "train", train_metrics, cur_step, train_time=train_time)

+        # ======================== Evaluating (after each epoch) ==============================
+
         if training_args.do_eval and (training_args.eval_steps is None or training_args.eval_steps <= 0):
             evaluate(input_rng, eval_dataset)
             save_results(epoch, cur_step)