ydshieh committed on
Commit
10c0ce9
·
1 Parent(s): cff52cf
Files changed (1) hide show
  1. run_image_captioning_flax.py +50 -40
run_image_captioning_flax.py CHANGED
@@ -123,11 +123,9 @@ class TrainingArguments:
123
  label_smoothing_factor: float = field(
124
  default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
125
  )
126
- adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
127
  num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
128
  warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
129
  logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
130
- save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
131
  eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
132
  dataloader_drop_last: bool = field(
133
  default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
@@ -1155,24 +1153,24 @@ def main():
1155
  commit_msg = f"Saving weights and logs of epoch {epoch + 1}- step {step}"
1156
  repo.push_to_hub(commit_message=commit_msg, blocking=False)
1157
 
1158
- def run_eval_or_test(rng, dataset, name, is_inside_training=True):
1159
 
1160
- if name not in ["valid", "test"]:
1161
- raise ValueError(f"`name` must be either \"valid\" or \"test\". Got {name} instead.")
1162
 
1163
- logger.info(f"*** {'Predict' if name == 'test' else 'Evaluate'} ***")
1164
 
1165
  metrics = []
1166
  preds = []
1167
  labels = []
1168
 
1169
- batches = get_batch_iter(rng, dataset, block_size=block_size, batch_size=eval_batch_size, keep_in_memory=False, shuffle=False, split=name)
1170
  steps = len(dataset) // eval_batch_size + int(len(dataset) % eval_batch_size > 0)
1171
- for _ in tqdm(range(steps), desc=f"{'Predicting' if name == 'test' else 'Evaluating'}...", position=2, leave=False):
1172
  # Model forward
1173
  batch = next(batches)
1174
  _labels = batch.get("labels", None)
1175
- if name == "valid" and _labels is None:
1176
  raise ValueError("Validation dataset requires `labels`")
1177
 
1178
  if _labels is not None:
@@ -1198,7 +1196,7 @@ def main():
1198
  if labels:
1199
  rouge_metrics, decoded_preds, decoded_labels = compute_metrics(preds, labels)
1200
  metrics.update(rouge_metrics)
1201
- rouge_desc = " ".join([f"{'Predict' if name == 'test' else 'Eval'} {key}: {value} |" for key, value in rouge_metrics.items()])
1202
  for pred, label in zip(decoded_preds, decoded_labels):
1203
  pred = pred.replace("\n", " ")
1204
  label = label.replace("\n", " ")
@@ -1215,8 +1213,8 @@ def main():
1215
 
1216
  if metrics:
1217
  # Print metrics and update progress bar
1218
- desc = f"{'Predict' if name == 'test' else 'Eval'} Loss: {metrics['loss']} | {rouge_desc})"
1219
- if is_inside_training:
1220
  desc = f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | " + desc
1221
  epochs.write(desc)
1222
  epochs.desc = desc
@@ -1225,7 +1223,7 @@ def main():
1225
  if jax.process_index() == 0:
1226
 
1227
  ckpt_dir = ""
1228
- if is_inside_training:
1229
  ckpt_dir = f'ckpt_epoch_{epoch + 1}_step_{cur_step}'
1230
  if not os.path.isdir(os.path.join(training_args.output_dir, ckpt_dir)):
1231
  os.makedirs(os.path.join(training_args.output_dir, ckpt_dir), exist_ok=True)
@@ -1233,9 +1231,9 @@ def main():
1233
  if metrics:
1234
 
1235
  # save final metrics in json
1236
- metrics = {f"{name}_{metric_name}": round(value.item(), 6) for metric_name, value in metrics.items()}
1237
- path = os.path.join(training_args.output_dir, ckpt_dir, f"{name}_results.json")
1238
- with open(path, "w") as f:
1239
  json.dump(metrics, f, indent=4, sort_keys=True)
1240
 
1241
  # Update report
@@ -1243,14 +1241,20 @@ def main():
1243
  fp.write(desc + '\n')
1244
 
1245
  # Save metrics
1246
- if has_tensorboard and is_inside_training:
1247
- write_metric(summary_writer, name, metrics, cur_step)
1248
 
1249
  # Save generations
1250
  if generations:
1251
- with open(os.path.join(training_args.output_dir, ckpt_dir, f'generation_{name}.json'), 'w', encoding='UTF-8') as fp:
1252
  json.dump(generations, fp, ensure_ascii=False, indent=4)
1253
 
 
 
 
 
 
 
1254
  input_rng = None
1255
 
1256
  if training_args.do_train:
@@ -1280,15 +1284,17 @@ def main():
1280
  train_metrics.append(train_metric)
1281
  train_time += time.time() - batch_start
1282
 
1283
- if cur_step % training_args.logging_steps == 0 or (training_args.eval_steps is not None and cur_step % training_args.eval_steps == 0) or cur_step % steps_per_epoch == 0:
 
 
 
 
1284
 
1285
- time_per_step = train_time / cur_step
 
1286
 
1287
- _train_metric = unreplicate(train_metric)
1288
- desc = f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | Loss: {_train_metric['loss']} | Learning Rate: {_train_metric['learning_rate']} | Time per step: {time_per_step})"
1289
- epochs.desc = desc
1290
- epochs.write(desc)
1291
  logger.info(desc)
 
1292
  with open(os.path.join(training_args.output_dir, 'report.txt'), 'a', encoding='UTF-8') as fp:
1293
  fp.write(desc + '\n')
1294
 
@@ -1296,34 +1302,38 @@ def main():
1296
  if has_tensorboard and jax.process_index() == 0:
1297
  write_metric(summary_writer, "train", train_metrics, cur_step, train_time=train_time)
1298
 
1299
- # ======================== Evaluating ==============================
 
 
 
1300
 
1301
- if training_args.do_eval and ((training_args.eval_steps is not None and cur_step % training_args.eval_steps == 0) or cur_step % steps_per_epoch == 0):
1302
- run_eval_or_test(input_rng, eval_dataset, name="valid", is_inside_training=True)
1303
 
1304
- # ======================== Prediction loop ==============================
 
1305
 
1306
- # run prediction after evaluation if specified, otherwise only after each epoch
1307
- if training_args.do_predict and training_args.do_predict_during_training and training_args.do_predict_after_evaluation:
1308
- run_eval_or_test(input_rng, predict_dataset, name='test', is_inside_training=True)
1309
 
1310
- # ======================== Save ==============================
 
1311
 
1312
- if cur_step % training_args.save_steps == 0:
1313
- save_results(epoch + 1, cur_step)
 
1314
 
1315
- # run prediction after each epoch (if not done during training)
1316
- if training_args.do_predict and training_args.do_predict_during_training and not training_args.do_predict_after_evaluation:
1317
- run_eval_or_test(input_rng, predict_dataset, name='test', is_inside_training=True)
1318
  save_results(epoch + 1, cur_step)
1319
 
 
 
1320
  # Create sampling rng
1321
  if input_rng is None:
1322
  rng, input_rng = jax.random.split(rng)
1323
 
1324
  # run prediction after each epoch (if not done during training)
1325
- if training_args.do_predict and not (training_args.do_train and training_args.do_predict_during_training):
1326
- run_eval_or_test(input_rng, predict_dataset, name='test', is_inside_training=False)
1327
 
1328
 
1329
  if __name__ == "__main__":
 
123
  label_smoothing_factor: float = field(
124
  default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
125
  )
 
126
  num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
127
  warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
128
  logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
 
129
  eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
130
  dataloader_drop_last: bool = field(
131
  default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
 
1153
  commit_msg = f"Saving weights and logs of epoch {epoch + 1}- step {step}"
1154
  repo.push_to_hub(commit_message=commit_msg, blocking=False)
1155
 
1156
+ def evaluation_loop(rng, dataset, split):
1157
 
1158
+ if split not in ["valid", "test"]:
1159
+ raise ValueError(f"`name` must be either \"valid\" or \"test\". Got {split} instead.")
1160
 
1161
+ logger.info(f"*** {'Predict' if split == 'test' else 'Evaluate'} ***")
1162
 
1163
  metrics = []
1164
  preds = []
1165
  labels = []
1166
 
1167
+ batches = get_batch_iter(rng, dataset, block_size=block_size, batch_size=eval_batch_size, keep_in_memory=False, shuffle=False, split=split)
1168
  steps = len(dataset) // eval_batch_size + int(len(dataset) % eval_batch_size > 0)
1169
+ for _ in tqdm(range(steps), desc=f"{'Predicting' if split == 'test' else 'Evaluating'}...", position=2, leave=False):
1170
  # Model forward
1171
  batch = next(batches)
1172
  _labels = batch.get("labels", None)
1173
+ if split == "valid" and _labels is None:
1174
  raise ValueError("Validation dataset requires `labels`")
1175
 
1176
  if _labels is not None:
 
1196
  if labels:
1197
  rouge_metrics, decoded_preds, decoded_labels = compute_metrics(preds, labels)
1198
  metrics.update(rouge_metrics)
1199
+ rouge_desc = " ".join([f"{'Predict' if split == 'test' else 'Eval'} {key}: {value} |" for key, value in rouge_metrics.items()])
1200
  for pred, label in zip(decoded_preds, decoded_labels):
1201
  pred = pred.replace("\n", " ")
1202
  label = label.replace("\n", " ")
 
1213
 
1214
  if metrics:
1215
  # Print metrics and update progress bar
1216
+ desc = f"{'Predict' if split == 'test' else 'Eval'} Loss: {metrics['loss']} | {rouge_desc})"
1217
+ if split == "valid":
1218
  desc = f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | " + desc
1219
  epochs.write(desc)
1220
  epochs.desc = desc
 
1223
  if jax.process_index() == 0:
1224
 
1225
  ckpt_dir = ""
1226
+ if split == "valid":
1227
  ckpt_dir = f'ckpt_epoch_{epoch + 1}_step_{cur_step}'
1228
  if not os.path.isdir(os.path.join(training_args.output_dir, ckpt_dir)):
1229
  os.makedirs(os.path.join(training_args.output_dir, ckpt_dir), exist_ok=True)
 
1231
  if metrics:
1232
 
1233
  # save final metrics in json
1234
+ metrics = {f"{split}_{metric_name}": round(value.item(), 6) for metric_name, value in metrics.items()}
1235
+ _path = os.path.join(training_args.output_dir, ckpt_dir, f"{split}_results.json")
1236
+ with open(_path, "w") as f:
1237
  json.dump(metrics, f, indent=4, sort_keys=True)
1238
 
1239
  # Update report
 
1241
  fp.write(desc + '\n')
1242
 
1243
  # Save metrics
1244
+ if has_tensorboard and split == "valid":
1245
+ write_metric(summary_writer, split, metrics, cur_step)
1246
 
1247
  # Save generations
1248
  if generations:
1249
+ with open(os.path.join(training_args.output_dir, ckpt_dir, f'generation_{split}.json'), 'w', encoding='UTF-8') as fp:
1250
  json.dump(generations, fp, ensure_ascii=False, indent=4)
1251
 
1252
+ def evaluate(rng, dataset):
1253
+ evaluation_loop(rng, dataset, split='eval')
1254
+
1255
+ def predict(rng, dataset):
1256
+ evaluation_loop(rng, dataset, split='test')
1257
+
1258
  input_rng = None
1259
 
1260
  if training_args.do_train:
 
1284
  train_metrics.append(train_metric)
1285
  train_time += time.time() - batch_start
1286
 
1287
+ time_per_step = train_time / cur_step
1288
+ _train_metric = unreplicate(train_metric)
1289
+ desc = f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | Loss: {_train_metric['loss']} | Learning Rate: {_train_metric['learning_rate']} | Time per step: {time_per_step})"
1290
+ epochs.desc = desc
1291
+ epochs.write(desc)
1292
 
1293
+ # log and save info
1294
+ if training_args.logging_steps > 0 and cur_step % training_args.logging_steps == 0:
1295
 
 
 
 
 
1296
  logger.info(desc)
1297
+
1298
  with open(os.path.join(training_args.output_dir, 'report.txt'), 'a', encoding='UTF-8') as fp:
1299
  fp.write(desc + '\n')
1300
 
 
1302
  if has_tensorboard and jax.process_index() == 0:
1303
  write_metric(summary_writer, "train", train_metrics, cur_step, train_time=train_time)
1304
 
1305
+ # ======================== Evaluating ==============================
1306
+ if training_args.eval_steps is not None and training_args.eval_steps > 0 and cur_step % training_args.eval_steps == 0:
1307
+ evaluate(input_rng, eval_dataset)
1308
+ save_results(epoch + 1, cur_step)
1309
 
1310
+ # ======================== Epoch End ==============================
 
1311
 
1312
+ # log and save info
1313
+ if training_args.logging_steps <= 0:
1314
 
1315
+ logger.info(desc)
 
 
1316
 
1317
+ with open(os.path.join(training_args.output_dir, 'report.txt'), 'a', encoding='UTF-8') as fp:
1318
+ fp.write(desc + '\n')
1319
 
1320
+ # Save metrics
1321
+ if has_tensorboard and jax.process_index() == 0:
1322
+ write_metric(summary_writer, "train", train_metrics, cur_step, train_time=train_time)
1323
 
1324
+ if training_args.eval_steps is None or training_args.eval_steps <= 0:
1325
+ evaluate(input_rng, eval_dataset)
 
1326
  save_results(epoch + 1, cur_step)
1327
 
1328
+ # ======================== Prediction loop ==============================
1329
+
1330
  # Create sampling rng
1331
  if input_rng is None:
1332
  rng, input_rng = jax.random.split(rng)
1333
 
1334
  # run prediction after each epoch (if not done during training)
1335
+ if training_args.do_predict:
1336
+ predict(input_rng, predict_dataset)
1337
 
1338
 
1339
  if __name__ == "__main__":