ydshieh
commited on
Commit
·
fd1b4a2
1
Parent(s):
cc3d5d3
improve code
Browse files- run_image_captioning_flax.py +12 -11
run_image_captioning_flax.py
CHANGED
@@ -1157,23 +1157,20 @@ def main():
|
|
1157 |
logger.info(f" Num test examples = {num_test_examples}")
|
1158 |
logger.info(f" Instantaneous test batch size per device = {training_args.per_device_eval_batch_size}")
|
1159 |
logger.info(f" Total test batch size (w. parallel & distributed) = {eval_batch_size}")
|
1160 |
-
logger.info(f" Total train batch size (w. parallel & distributed) = {eval_batch_size}")
|
1161 |
logger.info(f" Test steps = {test_steps}")
|
1162 |
|
1163 |
# create output directory
|
1164 |
if not os.path.isdir(os.path.join(training_args.output_dir)):
|
1165 |
os.makedirs(os.path.join(training_args.output_dir), exist_ok=True)
|
1166 |
|
1167 |
-
def
|
1168 |
|
1169 |
# save checkpoint after each epoch and push checkpoint to the hub
|
1170 |
if jax.process_index() == 0:
|
1171 |
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
1172 |
-
|
1173 |
-
|
1174 |
-
tokenizer.save_pretrained(os.path.join(training_args.output_dir, dir_name))
|
1175 |
if training_args.push_to_hub:
|
1176 |
-
commit_msg = f"Saving weights and logs of epoch {epoch + 1}- step {step}"
|
1177 |
repo.push_to_hub(commit_message=commit_msg, blocking=False)
|
1178 |
|
1179 |
def evaluation_loop(rng: jax.random.PRNGKey, dataset: Dataset, metric_key_prefix: str = "eval", ckpt_dir: str = "", is_prediction=False):
|
@@ -1242,7 +1239,7 @@ def main():
|
|
1242 |
if metrics:
|
1243 |
# Print metrics and update progress bar
|
1244 |
desc = f"{'Predict' if is_prediction else 'Eval'} Loss: {metrics['loss']} | {rouge_desc})"
|
1245 |
-
if training_args.do_train:
|
1246 |
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | " + desc
|
1247 |
epochs.write(desc)
|
1248 |
epochs.desc = desc
|
@@ -1338,8 +1335,10 @@ def main():
|
|
1338 |
# ======================== Evaluating (inside an epoch) ==============================
|
1339 |
|
1340 |
if training_args.do_eval and (training_args.eval_steps is not None and training_args.eval_steps > 0) and cur_step % training_args.eval_steps == 0:
|
1341 |
-
|
1342 |
-
|
|
|
|
|
1343 |
|
1344 |
# ======================== Epoch End ==============================
|
1345 |
|
@@ -1358,8 +1357,10 @@ def main():
|
|
1358 |
# ======================== Evaluating (after each epoch) ==============================
|
1359 |
|
1360 |
if training_args.do_eval and (training_args.eval_steps is None or training_args.eval_steps <= 0):
|
1361 |
-
|
1362 |
-
|
|
|
|
|
1363 |
|
1364 |
# ======================== Evaluating | Predicting ==============================
|
1365 |
|
|
|
1157 |
logger.info(f" Num test examples = {num_test_examples}")
|
1158 |
logger.info(f" Instantaneous test batch size per device = {training_args.per_device_eval_batch_size}")
|
1159 |
logger.info(f" Total test batch size (w. parallel & distributed) = {eval_batch_size}")
|
|
|
1160 |
logger.info(f" Test steps = {test_steps}")
|
1161 |
|
1162 |
# create output directory
|
1163 |
if not os.path.isdir(os.path.join(training_args.output_dir)):
|
1164 |
os.makedirs(os.path.join(training_args.output_dir), exist_ok=True)
|
1165 |
|
1166 |
+
def save_ckpt(ckpt_dir: str, commit_msg: str =""):
|
1167 |
|
1168 |
# save checkpoint after each epoch and push checkpoint to the hub
|
1169 |
if jax.process_index() == 0:
|
1170 |
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
1171 |
+
model.save_pretrained(os.path.join(training_args.output_dir, ckpt_dir), params=params)
|
1172 |
+
tokenizer.save_pretrained(os.path.join(training_args.output_dir, ckpt_dir))
|
|
|
1173 |
if training_args.push_to_hub:
|
|
|
1174 |
repo.push_to_hub(commit_message=commit_msg, blocking=False)
|
1175 |
|
1176 |
def evaluation_loop(rng: jax.random.PRNGKey, dataset: Dataset, metric_key_prefix: str = "eval", ckpt_dir: str = "", is_prediction=False):
|
|
|
1239 |
if metrics:
|
1240 |
# Print metrics and update progress bar
|
1241 |
desc = f"{'Predict' if is_prediction else 'Eval'} Loss: {metrics['loss']} | {rouge_desc})"
|
1242 |
+
if training_args.do_train and not is_prediction:
|
1243 |
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | " + desc
|
1244 |
epochs.write(desc)
|
1245 |
epochs.desc = desc
|
|
|
1335 |
# ======================== Evaluating (inside an epoch) ==============================
|
1336 |
|
1337 |
if training_args.do_eval and (training_args.eval_steps is not None and training_args.eval_steps > 0) and cur_step % training_args.eval_steps == 0:
|
1338 |
+
ckpt_dir = f"ckpt_epoch_{epoch + 1}_step_{cur_step}"
|
1339 |
+
commit_msg = f"Saving weights and logs of epoch {epoch + 1} - step {cur_step}"
|
1340 |
+
evaluate(input_rng, eval_dataset, ckpt_dir)
|
1341 |
+
save_ckpt(ckpt_dir=ckpt_dir, commit_msg=commit_msg)
|
1342 |
|
1343 |
# ======================== Epoch End ==============================
|
1344 |
|
|
|
1357 |
# ======================== Evaluating (after each epoch) ==============================
|
1358 |
|
1359 |
if training_args.do_eval and (training_args.eval_steps is None or training_args.eval_steps <= 0):
|
1360 |
+
ckpt_dir = f"ckpt_epoch_{epoch + 1}_step_{cur_step}"
|
1361 |
+
commit_msg = f"Saving weights and logs of epoch {epoch + 1} - step {cur_step}"
|
1362 |
+
evaluate(input_rng, eval_dataset, ckpt_dir)
|
1363 |
+
save_ckpt(ckpt_dir=ckpt_dir, commit_msg=commit_msg)
|
1364 |
|
1365 |
# ======================== Evaluating | Predicting ==============================
|
1366 |
|