ydshieh commited on
Commit
7f250ee
·
1 Parent(s): fa45968

improve code

Browse files
Files changed (1) hide show
  1. run_image_captioning_flax.py +42 -34
run_image_captioning_flax.py CHANGED
@@ -957,12 +957,6 @@ def main():
957
 
958
  _ds = ds.select(selected_indices)
959
 
960
- names = {
961
- "train": "train",
962
- "valid": "validation",
963
- "test": "prediction",
964
- }
965
-
966
  _ds = _ds.map(
967
  feature_extraction_fn,
968
  batched=True,
@@ -971,7 +965,7 @@ def main():
971
  load_from_cache_file=not data_args.overwrite_cache,
972
  features=features,
973
  keep_in_memory=keep_in_memory,
974
- desc=f"Running feature extraction on {names[split]} dataset".replace(" ", " "),
975
  )
976
  _ds = _ds.with_format("numpy")
977
 
@@ -1183,25 +1177,30 @@ def main():
1183
  commit_msg = f"Saving weights and logs of epoch {epoch + 1}- step {step}"
1184
  repo.push_to_hub(commit_message=commit_msg, blocking=False)
1185
 
1186
- def evaluation_loop(rng: jax.random.PRNGKey, dataset: Dataset, split: str):
1187
-
1188
- if split not in ["valid", "test"]:
1189
- raise ValueError(f"`name` must be either \"valid\" or \"test\". Got {split} instead.")
1190
 
1191
- logger.info(f"*** {'Predict' if split == 'test' else 'Evaluate'} ***")
1192
 
1193
  metrics = []
1194
  preds = []
1195
  labels = []
1196
 
1197
- batches = blockwise_data_loader(rng, dataset, block_size=training_args.block_size, batch_size=eval_batch_size, keep_in_memory=False, shuffle=False, split=split)
 
 
 
 
 
 
 
 
1198
  steps = len(dataset) // eval_batch_size
1199
- for _ in tqdm(range(steps), desc=f"{'Predicting' if split == 'test' else 'Evaluating'}...", position=2, leave=False):
1200
  # Model forward
1201
  batch = next(batches)
1202
  _labels = batch.get("labels", None)
1203
- if split == "valid" and _labels is None:
1204
- raise ValueError("Validation dataset requires `labels`")
1205
 
1206
  if _labels is not None:
1207
  _metrics = p_eval_step(state.params, batch)
@@ -1226,7 +1225,7 @@ def main():
1226
  if labels:
1227
  rouge_metrics, decoded_preds, decoded_labels = compute_metrics(preds, labels)
1228
  metrics.update(rouge_metrics)
1229
- rouge_desc = " ".join([f"{'Predict' if split == 'test' else 'Eval'} {key}: {value} |" for key, value in rouge_metrics.items()])
1230
  for pred, label in zip(decoded_preds, decoded_labels):
1231
  pred = pred.replace("\n", " ")
1232
  label = label.replace("\n", " ")
@@ -1243,8 +1242,8 @@ def main():
1243
 
1244
  if metrics:
1245
  # Print metrics and update progress bar
1246
- desc = f"{'Predict' if split == 'test' else 'Eval'} Loss: {metrics['loss']} | {rouge_desc})"
1247
- if split == "valid":
1248
  desc = f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | " + desc
1249
  epochs.write(desc)
1250
  epochs.desc = desc
@@ -1252,11 +1251,8 @@ def main():
1252
 
1253
  if jax.process_index() == 0:
1254
 
1255
- ckpt_dir = ""
1256
- if split == "valid":
1257
- ckpt_dir = f'ckpt_epoch_{epoch + 1}_step_{cur_step}'
1258
- if not os.path.isdir(os.path.join(training_args.output_dir, ckpt_dir)):
1259
- os.makedirs(os.path.join(training_args.output_dir, ckpt_dir), exist_ok=True)
1260
 
1261
  if metrics:
1262
 
@@ -1271,7 +1267,7 @@ def main():
1271
  fp.write(desc + '\n')
1272
 
1273
  # Save metrics
1274
- if has_tensorboard and split == "valid":
1275
  write_metric(summary_writer, split, metrics, cur_step)
1276
 
1277
  # Save generations
@@ -1279,11 +1275,11 @@ def main():
1279
  with open(os.path.join(training_args.output_dir, ckpt_dir, f'generation_{split}.json'), 'w', encoding='UTF-8') as fp:
1280
  json.dump(generations, fp, ensure_ascii=False, indent=4)
1281
 
1282
- def evaluate(rng: jax.random.PRNGKey, dataset: Dataset):
1283
- evaluation_loop(rng, dataset, split='valid')
1284
 
1285
  def predict(rng: jax.random.PRNGKey, dataset: Dataset):
1286
- evaluation_loop(rng, dataset, split='test')
1287
 
1288
  input_rng = None
1289
 
@@ -1302,7 +1298,15 @@ def main():
1302
 
1303
  train_metrics = []
1304
 
1305
- train_batches = blockwise_data_loader(input_rng, train_dataset, block_size=training_args.block_size, batch_size=train_batch_size, keep_in_memory=True, shuffle=True, split="train")
 
 
 
 
 
 
 
 
1306
 
1307
  # train
1308
  for (batch_idx, _) in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)):
@@ -1332,10 +1336,10 @@ def main():
1332
  if has_tensorboard and jax.process_index() == 0:
1333
  write_metric(summary_writer, "train", train_metrics, cur_step, train_time=train_time)
1334
 
1335
- # ======================== Evaluating (inside epoch) ==============================
1336
 
1337
  if training_args.do_eval and (training_args.eval_steps is not None and training_args.eval_steps > 0) and cur_step % training_args.eval_steps == 0:
1338
- evaluate(input_rng, eval_dataset)
1339
  save_results(epoch, cur_step)
1340
 
1341
  # ======================== Epoch End ==============================
@@ -1355,16 +1359,20 @@ def main():
1355
  # ======================== Evaluating (after each epoch) ==============================
1356
 
1357
  if training_args.do_eval and (training_args.eval_steps is None or training_args.eval_steps <= 0):
1358
- evaluate(input_rng, eval_dataset)
1359
  save_results(epoch, cur_step)
1360
 
1361
- # ======================== Prediction loop ==============================
1362
 
1363
  # Create sampling rng
1364
  if input_rng is None:
1365
  rng, input_rng = jax.random.split(rng)
1366
 
1367
- # run prediction after each epoch (if not done during training)
 
 
 
 
1368
  if training_args.do_predict:
1369
  predict(input_rng, predict_dataset)
1370
 
 
957
 
958
  _ds = ds.select(selected_indices)
959
 
 
 
 
 
 
 
960
  _ds = _ds.map(
961
  feature_extraction_fn,
962
  batched=True,
 
965
  load_from_cache_file=not data_args.overwrite_cache,
966
  features=features,
967
  keep_in_memory=keep_in_memory,
968
+ desc=f"Running feature extraction on {split} dataset".replace(" ", " "),
969
  )
970
  _ds = _ds.with_format("numpy")
971
 
 
1177
  commit_msg = f"Saving weights and logs of epoch {epoch + 1}- step {step}"
1178
  repo.push_to_hub(commit_message=commit_msg, blocking=False)
1179
 
1180
+ def evaluation_loop(rng: jax.random.PRNGKey, dataset: Dataset, split: str = "eval", ckpt_dir: str = "", is_prediction=False):
 
 
 
1181
 
1182
+ logger.info(f"*** {'Predict' if is_prediction else 'Evaluate'} ***")
1183
 
1184
  metrics = []
1185
  preds = []
1186
  labels = []
1187
 
1188
+ batches = blockwise_data_loader(
1189
+ rng,
1190
+ dataset,
1191
+ block_size=training_args.block_size,
1192
+ batch_size=eval_batch_size,
1193
+ keep_in_memory=False,
1194
+ shuffle=False,
1195
+ split="prediction" if is_prediction else "validation",
1196
+ )
1197
  steps = len(dataset) // eval_batch_size
1198
+ for _ in tqdm(range(steps), desc=f"{'Predicting' if is_prediction else 'Evaluating'}...", position=2, leave=False):
1199
  # Model forward
1200
  batch = next(batches)
1201
  _labels = batch.get("labels", None)
1202
+ if not is_prediction and _labels is None:
1203
+ raise ValueError("Evaluation requires the validation dataset to have `labels`")
1204
 
1205
  if _labels is not None:
1206
  _metrics = p_eval_step(state.params, batch)
 
1225
  if labels:
1226
  rouge_metrics, decoded_preds, decoded_labels = compute_metrics(preds, labels)
1227
  metrics.update(rouge_metrics)
1228
+ rouge_desc = " ".join([f"{'Predict' if is_prediction else 'Eval'} {key}: {value} |" for key, value in rouge_metrics.items()])
1229
  for pred, label in zip(decoded_preds, decoded_labels):
1230
  pred = pred.replace("\n", " ")
1231
  label = label.replace("\n", " ")
 
1242
 
1243
  if metrics:
1244
  # Print metrics and update progress bar
1245
+ desc = f"{'Predict' if is_prediction else 'Eval'} Loss: {metrics['loss']} | {rouge_desc})"
1246
+ if not is_prediction:
1247
  desc = f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | " + desc
1248
  epochs.write(desc)
1249
  epochs.desc = desc
 
1251
 
1252
  if jax.process_index() == 0:
1253
 
1254
+ if not os.path.isdir(os.path.join(training_args.output_dir, ckpt_dir)):
1255
+ os.makedirs(os.path.join(training_args.output_dir, ckpt_dir), exist_ok=True)
 
 
 
1256
 
1257
  if metrics:
1258
 
 
1267
  fp.write(desc + '\n')
1268
 
1269
  # Save metrics
1270
+ if has_tensorboard and is_prediction:
1271
  write_metric(summary_writer, split, metrics, cur_step)
1272
 
1273
  # Save generations
 
1275
  with open(os.path.join(training_args.output_dir, ckpt_dir, f'generation_{split}.json'), 'w', encoding='UTF-8') as fp:
1276
  json.dump(generations, fp, ensure_ascii=False, indent=4)
1277
 
1278
+ def evaluate(rng: jax.random.PRNGKey, dataset: Dataset, ckpt_dir: str = ""):
1279
+ evaluation_loop(rng, dataset, split='eval', ckpt_dir=ckpt_dir)
1280
 
1281
  def predict(rng: jax.random.PRNGKey, dataset: Dataset):
1282
+ evaluation_loop(rng, dataset, split='test', is_prediction=True)
1283
 
1284
  input_rng = None
1285
 
 
1298
 
1299
  train_metrics = []
1300
 
1301
+ train_batches = blockwise_data_loader(
1302
+ input_rng,
1303
+ train_dataset,
1304
+ block_size=training_args.block_size,
1305
+ batch_size=train_batch_size,
1306
+ keep_in_memory=True,
1307
+ shuffle=True,
1308
+ split="train"
1309
+ )
1310
 
1311
  # train
1312
  for (batch_idx, _) in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)):
 
1336
  if has_tensorboard and jax.process_index() == 0:
1337
  write_metric(summary_writer, "train", train_metrics, cur_step, train_time=train_time)
1338
 
1339
+ # ======================== Evaluating (inside an epoch) ==============================
1340
 
1341
  if training_args.do_eval and (training_args.eval_steps is not None and training_args.eval_steps > 0) and cur_step % training_args.eval_steps == 0:
1342
+ evaluate(input_rng, eval_dataset, ckpt_dir=f"ckpt_epoch_{epoch + 1}_step_{cur_step}")
1343
  save_results(epoch, cur_step)
1344
 
1345
  # ======================== Epoch End ==============================
 
1359
  # ======================== Evaluating (after each epoch) ==============================
1360
 
1361
  if training_args.do_eval and (training_args.eval_steps is None or training_args.eval_steps <= 0):
1362
+ evaluate(input_rng, eval_dataset, ckpt_dir=f"ckpt_epoch_{epoch + 1}_step_{cur_step}")
1363
  save_results(epoch, cur_step)
1364
 
1365
+ # ======================== Evaluating | Predicting ==============================
1366
 
1367
  # Create sampling rng
1368
  if input_rng is None:
1369
  rng, input_rng = jax.random.split(rng)
1370
 
1371
+ # run evaluation without training
1372
+ if training_args.do_eval and not training_args.do_train:
1373
+ evaluate(input_rng, eval_dataset)
1374
+
1375
+ # run prediction after (or without) training
1376
  if training_args.do_predict:
1377
  predict(input_rng, predict_dataset)
1378