ydshieh committed on
Commit
fd1b4a2
·
1 Parent(s): cc3d5d3

improve code

Browse files
Files changed (1) hide show
  1. run_image_captioning_flax.py +12 -11
run_image_captioning_flax.py CHANGED
@@ -1157,23 +1157,20 @@ def main():
1157
  logger.info(f" Num test examples = {num_test_examples}")
1158
  logger.info(f" Instantaneous test batch size per device = {training_args.per_device_eval_batch_size}")
1159
  logger.info(f" Total test batch size (w. parallel & distributed) = {eval_batch_size}")
1160
- logger.info(f" Total train batch size (w. parallel & distributed) = {eval_batch_size}")
1161
  logger.info(f" Test steps = {test_steps}")
1162
 
1163
  # create output directory
1164
  if not os.path.isdir(os.path.join(training_args.output_dir)):
1165
  os.makedirs(os.path.join(training_args.output_dir), exist_ok=True)
1166
 
1167
- def save_results(epoch: int, step: int):
1168
 
1169
  # save checkpoint after each epoch and push checkpoint to the hub
1170
  if jax.process_index() == 0:
1171
  params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
1172
- dir_name = f'ckpt_epoch_{epoch + 1}_step_{step}'
1173
- model.save_pretrained(os.path.join(training_args.output_dir, dir_name), params=params)
1174
- tokenizer.save_pretrained(os.path.join(training_args.output_dir, dir_name))
1175
  if training_args.push_to_hub:
1176
- commit_msg = f"Saving weights and logs of epoch {epoch + 1}- step {step}"
1177
  repo.push_to_hub(commit_message=commit_msg, blocking=False)
1178
 
1179
  def evaluation_loop(rng: jax.random.PRNGKey, dataset: Dataset, metric_key_prefix: str = "eval", ckpt_dir: str = "", is_prediction=False):
@@ -1242,7 +1239,7 @@ def main():
1242
  if metrics:
1243
  # Print metrics and update progress bar
1244
  desc = f"{'Predict' if is_prediction else 'Eval'} Loss: {metrics['loss']} | {rouge_desc})"
1245
- if training_args.do_train:
1246
  desc = f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | " + desc
1247
  epochs.write(desc)
1248
  epochs.desc = desc
@@ -1338,8 +1335,10 @@ def main():
1338
  # ======================== Evaluating (inside an epoch) ==============================
1339
 
1340
  if training_args.do_eval and (training_args.eval_steps is not None and training_args.eval_steps > 0) and cur_step % training_args.eval_steps == 0:
1341
- evaluate(input_rng, eval_dataset, ckpt_dir=f"ckpt_epoch_{epoch + 1}_step_{cur_step}")
1342
- save_results(epoch, cur_step)
 
 
1343
 
1344
  # ======================== Epoch End ==============================
1345
 
@@ -1358,8 +1357,10 @@ def main():
1358
  # ======================== Evaluating (after each epoch) ==============================
1359
 
1360
  if training_args.do_eval and (training_args.eval_steps is None or training_args.eval_steps <= 0):
1361
- evaluate(input_rng, eval_dataset, ckpt_dir=f"ckpt_epoch_{epoch + 1}_step_{cur_step}")
1362
- save_results(epoch, cur_step)
 
 
1363
 
1364
  # ======================== Evaluating | Predicting ==============================
1365
 
 
1157
  logger.info(f" Num test examples = {num_test_examples}")
1158
  logger.info(f" Instantaneous test batch size per device = {training_args.per_device_eval_batch_size}")
1159
  logger.info(f" Total test batch size (w. parallel & distributed) = {eval_batch_size}")
 
1160
  logger.info(f" Test steps = {test_steps}")
1161
 
1162
  # create output directory
1163
  if not os.path.isdir(os.path.join(training_args.output_dir)):
1164
  os.makedirs(os.path.join(training_args.output_dir), exist_ok=True)
1165
 
1166
+ def save_ckpt(ckpt_dir: str, commit_msg: str =""):
1167
 
1168
  # save checkpoint after each epoch and push checkpoint to the hub
1169
  if jax.process_index() == 0:
1170
  params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
1171
+ model.save_pretrained(os.path.join(training_args.output_dir, ckpt_dir), params=params)
1172
+ tokenizer.save_pretrained(os.path.join(training_args.output_dir, ckpt_dir))
 
1173
  if training_args.push_to_hub:
 
1174
  repo.push_to_hub(commit_message=commit_msg, blocking=False)
1175
 
1176
  def evaluation_loop(rng: jax.random.PRNGKey, dataset: Dataset, metric_key_prefix: str = "eval", ckpt_dir: str = "", is_prediction=False):
 
1239
  if metrics:
1240
  # Print metrics and update progress bar
1241
  desc = f"{'Predict' if is_prediction else 'Eval'} Loss: {metrics['loss']} | {rouge_desc})"
1242
+ if training_args.do_train and not is_prediction:
1243
  desc = f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | " + desc
1244
  epochs.write(desc)
1245
  epochs.desc = desc
 
1335
  # ======================== Evaluating (inside an epoch) ==============================
1336
 
1337
  if training_args.do_eval and (training_args.eval_steps is not None and training_args.eval_steps > 0) and cur_step % training_args.eval_steps == 0:
1338
+ ckpt_dir = f"ckpt_epoch_{epoch + 1}_step_{cur_step}"
1339
+ commit_msg = f"Saving weights and logs of epoch {epoch + 1} - step {cur_step}"
1340
+ evaluate(input_rng, eval_dataset, ckpt_dir)
1341
+ save_ckpt(ckpt_dir=ckpt_dir, commit_msg=commit_msg)
1342
 
1343
  # ======================== Epoch End ==============================
1344
 
 
1357
  # ======================== Evaluating (after each epoch) ==============================
1358
 
1359
  if training_args.do_eval and (training_args.eval_steps is None or training_args.eval_steps <= 0):
1360
+ ckpt_dir = f"ckpt_epoch_{epoch + 1}_step_{cur_step}"
1361
+ commit_msg = f"Saving weights and logs of epoch {epoch + 1} - step {cur_step}"
1362
+ evaluate(input_rng, eval_dataset, ckpt_dir)
1363
+ save_ckpt(ckpt_dir=ckpt_dir, commit_msg=commit_msg)
1364
 
1365
  # ======================== Evaluating | Predicting ==============================
1366