ydshieh
commited on
Commit
·
7f250ee
1
Parent(s):
fa45968
improve code
Browse files- run_image_captioning_flax.py +42 -34
run_image_captioning_flax.py
CHANGED
|
@@ -957,12 +957,6 @@ def main():
|
|
| 957 |
|
| 958 |
_ds = ds.select(selected_indices)
|
| 959 |
|
| 960 |
-
names = {
|
| 961 |
-
"train": "train",
|
| 962 |
-
"valid": "validation",
|
| 963 |
-
"test": "prediction",
|
| 964 |
-
}
|
| 965 |
-
|
| 966 |
_ds = _ds.map(
|
| 967 |
feature_extraction_fn,
|
| 968 |
batched=True,
|
|
@@ -971,7 +965,7 @@ def main():
|
|
| 971 |
load_from_cache_file=not data_args.overwrite_cache,
|
| 972 |
features=features,
|
| 973 |
keep_in_memory=keep_in_memory,
|
| 974 |
-
desc=f"Running feature extraction on {
|
| 975 |
)
|
| 976 |
_ds = _ds.with_format("numpy")
|
| 977 |
|
|
@@ -1183,25 +1177,30 @@ def main():
|
|
| 1183 |
commit_msg = f"Saving weights and logs of epoch {epoch + 1}- step {step}"
|
| 1184 |
repo.push_to_hub(commit_message=commit_msg, blocking=False)
|
| 1185 |
|
| 1186 |
-
def evaluation_loop(rng: jax.random.PRNGKey, dataset: Dataset, split: str):
|
| 1187 |
-
|
| 1188 |
-
if split not in ["valid", "test"]:
|
| 1189 |
-
raise ValueError(f"`name` must be either \"valid\" or \"test\". Got {split} instead.")
|
| 1190 |
|
| 1191 |
-
logger.info(f"*** {'Predict' if
|
| 1192 |
|
| 1193 |
metrics = []
|
| 1194 |
preds = []
|
| 1195 |
labels = []
|
| 1196 |
|
| 1197 |
-
batches = blockwise_data_loader(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1198 |
steps = len(dataset) // eval_batch_size
|
| 1199 |
-
for _ in tqdm(range(steps), desc=f"{'Predicting' if
|
| 1200 |
# Model forward
|
| 1201 |
batch = next(batches)
|
| 1202 |
_labels = batch.get("labels", None)
|
| 1203 |
-
if
|
| 1204 |
-
raise ValueError("
|
| 1205 |
|
| 1206 |
if _labels is not None:
|
| 1207 |
_metrics = p_eval_step(state.params, batch)
|
|
@@ -1226,7 +1225,7 @@ def main():
|
|
| 1226 |
if labels:
|
| 1227 |
rouge_metrics, decoded_preds, decoded_labels = compute_metrics(preds, labels)
|
| 1228 |
metrics.update(rouge_metrics)
|
| 1229 |
-
rouge_desc = " ".join([f"{'Predict' if
|
| 1230 |
for pred, label in zip(decoded_preds, decoded_labels):
|
| 1231 |
pred = pred.replace("\n", " ")
|
| 1232 |
label = label.replace("\n", " ")
|
|
@@ -1243,8 +1242,8 @@ def main():
|
|
| 1243 |
|
| 1244 |
if metrics:
|
| 1245 |
# Print metrics and update progress bar
|
| 1246 |
-
desc = f"{'Predict' if
|
| 1247 |
-
if
|
| 1248 |
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | " + desc
|
| 1249 |
epochs.write(desc)
|
| 1250 |
epochs.desc = desc
|
|
@@ -1252,11 +1251,8 @@ def main():
|
|
| 1252 |
|
| 1253 |
if jax.process_index() == 0:
|
| 1254 |
|
| 1255 |
-
|
| 1256 |
-
|
| 1257 |
-
ckpt_dir = f'ckpt_epoch_{epoch + 1}_step_{cur_step}'
|
| 1258 |
-
if not os.path.isdir(os.path.join(training_args.output_dir, ckpt_dir)):
|
| 1259 |
-
os.makedirs(os.path.join(training_args.output_dir, ckpt_dir), exist_ok=True)
|
| 1260 |
|
| 1261 |
if metrics:
|
| 1262 |
|
|
@@ -1271,7 +1267,7 @@ def main():
|
|
| 1271 |
fp.write(desc + '\n')
|
| 1272 |
|
| 1273 |
# Save metrics
|
| 1274 |
-
if has_tensorboard and
|
| 1275 |
write_metric(summary_writer, split, metrics, cur_step)
|
| 1276 |
|
| 1277 |
# Save generations
|
|
@@ -1279,11 +1275,11 @@ def main():
|
|
| 1279 |
with open(os.path.join(training_args.output_dir, ckpt_dir, f'generation_{split}.json'), 'w', encoding='UTF-8') as fp:
|
| 1280 |
json.dump(generations, fp, ensure_ascii=False, indent=4)
|
| 1281 |
|
| 1282 |
-
def evaluate(rng: jax.random.PRNGKey, dataset: Dataset):
|
| 1283 |
-
evaluation_loop(rng, dataset, split='
|
| 1284 |
|
| 1285 |
def predict(rng: jax.random.PRNGKey, dataset: Dataset):
|
| 1286 |
-
evaluation_loop(rng, dataset, split='test')
|
| 1287 |
|
| 1288 |
input_rng = None
|
| 1289 |
|
|
@@ -1302,7 +1298,15 @@ def main():
|
|
| 1302 |
|
| 1303 |
train_metrics = []
|
| 1304 |
|
| 1305 |
-
train_batches = blockwise_data_loader(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1306 |
|
| 1307 |
# train
|
| 1308 |
for (batch_idx, _) in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)):
|
|
@@ -1332,10 +1336,10 @@ def main():
|
|
| 1332 |
if has_tensorboard and jax.process_index() == 0:
|
| 1333 |
write_metric(summary_writer, "train", train_metrics, cur_step, train_time=train_time)
|
| 1334 |
|
| 1335 |
-
# ======================== Evaluating (inside epoch) ==============================
|
| 1336 |
|
| 1337 |
if training_args.do_eval and (training_args.eval_steps is not None and training_args.eval_steps > 0) and cur_step % training_args.eval_steps == 0:
|
| 1338 |
-
evaluate(input_rng, eval_dataset)
|
| 1339 |
save_results(epoch, cur_step)
|
| 1340 |
|
| 1341 |
# ======================== Epoch End ==============================
|
|
@@ -1355,16 +1359,20 @@ def main():
|
|
| 1355 |
# ======================== Evaluating (after each epoch) ==============================
|
| 1356 |
|
| 1357 |
if training_args.do_eval and (training_args.eval_steps is None or training_args.eval_steps <= 0):
|
| 1358 |
-
evaluate(input_rng, eval_dataset)
|
| 1359 |
save_results(epoch, cur_step)
|
| 1360 |
|
| 1361 |
-
# ========================
|
| 1362 |
|
| 1363 |
# Create sampling rng
|
| 1364 |
if input_rng is None:
|
| 1365 |
rng, input_rng = jax.random.split(rng)
|
| 1366 |
|
| 1367 |
-
# run
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1368 |
if training_args.do_predict:
|
| 1369 |
predict(input_rng, predict_dataset)
|
| 1370 |
|
|
|
|
| 957 |
|
| 958 |
_ds = ds.select(selected_indices)
|
| 959 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 960 |
_ds = _ds.map(
|
| 961 |
feature_extraction_fn,
|
| 962 |
batched=True,
|
|
|
|
| 965 |
load_from_cache_file=not data_args.overwrite_cache,
|
| 966 |
features=features,
|
| 967 |
keep_in_memory=keep_in_memory,
|
| 968 |
+
desc=f"Running feature extraction on {split} dataset".replace(" ", " "),
|
| 969 |
)
|
| 970 |
_ds = _ds.with_format("numpy")
|
| 971 |
|
|
|
|
| 1177 |
commit_msg = f"Saving weights and logs of epoch {epoch + 1}- step {step}"
|
| 1178 |
repo.push_to_hub(commit_message=commit_msg, blocking=False)
|
| 1179 |
|
| 1180 |
+
def evaluation_loop(rng: jax.random.PRNGKey, dataset: Dataset, split: str = "eval", ckpt_dir: str = "", is_prediction=False):
|
|
|
|
|
|
|
|
|
|
| 1181 |
|
| 1182 |
+
logger.info(f"*** {'Predict' if is_prediction else 'Evaluate'} ***")
|
| 1183 |
|
| 1184 |
metrics = []
|
| 1185 |
preds = []
|
| 1186 |
labels = []
|
| 1187 |
|
| 1188 |
+
batches = blockwise_data_loader(
|
| 1189 |
+
rng,
|
| 1190 |
+
dataset,
|
| 1191 |
+
block_size=training_args.block_size,
|
| 1192 |
+
batch_size=eval_batch_size,
|
| 1193 |
+
keep_in_memory=False,
|
| 1194 |
+
shuffle=False,
|
| 1195 |
+
split="prediction" if is_prediction else "validation",
|
| 1196 |
+
)
|
| 1197 |
steps = len(dataset) // eval_batch_size
|
| 1198 |
+
for _ in tqdm(range(steps), desc=f"{'Predicting' if is_prediction else 'Evaluating'}...", position=2, leave=False):
|
| 1199 |
# Model forward
|
| 1200 |
batch = next(batches)
|
| 1201 |
_labels = batch.get("labels", None)
|
| 1202 |
+
if not is_prediction and _labels is None:
|
| 1203 |
+
raise ValueError("Evaluation requires the validation dataset to have `labels`")
|
| 1204 |
|
| 1205 |
if _labels is not None:
|
| 1206 |
_metrics = p_eval_step(state.params, batch)
|
|
|
|
| 1225 |
if labels:
|
| 1226 |
rouge_metrics, decoded_preds, decoded_labels = compute_metrics(preds, labels)
|
| 1227 |
metrics.update(rouge_metrics)
|
| 1228 |
+
rouge_desc = " ".join([f"{'Predict' if is_prediction else 'Eval'} {key}: {value} |" for key, value in rouge_metrics.items()])
|
| 1229 |
for pred, label in zip(decoded_preds, decoded_labels):
|
| 1230 |
pred = pred.replace("\n", " ")
|
| 1231 |
label = label.replace("\n", " ")
|
|
|
|
| 1242 |
|
| 1243 |
if metrics:
|
| 1244 |
# Print metrics and update progress bar
|
| 1245 |
+
desc = f"{'Predict' if is_prediction else 'Eval'} Loss: {metrics['loss']} | {rouge_desc})"
|
| 1246 |
+
if not is_prediction:
|
| 1247 |
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | " + desc
|
| 1248 |
epochs.write(desc)
|
| 1249 |
epochs.desc = desc
|
|
|
|
| 1251 |
|
| 1252 |
if jax.process_index() == 0:
|
| 1253 |
|
| 1254 |
+
if not os.path.isdir(os.path.join(training_args.output_dir, ckpt_dir)):
|
| 1255 |
+
os.makedirs(os.path.join(training_args.output_dir, ckpt_dir), exist_ok=True)
|
|
|
|
|
|
|
|
|
|
| 1256 |
|
| 1257 |
if metrics:
|
| 1258 |
|
|
|
|
| 1267 |
fp.write(desc + '\n')
|
| 1268 |
|
| 1269 |
# Save metrics
|
| 1270 |
+
if has_tensorboard and is_prediction:
|
| 1271 |
write_metric(summary_writer, split, metrics, cur_step)
|
| 1272 |
|
| 1273 |
# Save generations
|
|
|
|
| 1275 |
with open(os.path.join(training_args.output_dir, ckpt_dir, f'generation_{split}.json'), 'w', encoding='UTF-8') as fp:
|
| 1276 |
json.dump(generations, fp, ensure_ascii=False, indent=4)
|
| 1277 |
|
| 1278 |
+
def evaluate(rng: jax.random.PRNGKey, dataset: Dataset, ckpt_dir: str = ""):
|
| 1279 |
+
evaluation_loop(rng, dataset, split='eval', ckpt_dir=ckpt_dir)
|
| 1280 |
|
| 1281 |
def predict(rng: jax.random.PRNGKey, dataset: Dataset):
|
| 1282 |
+
evaluation_loop(rng, dataset, split='test', is_prediction=True)
|
| 1283 |
|
| 1284 |
input_rng = None
|
| 1285 |
|
|
|
|
| 1298 |
|
| 1299 |
train_metrics = []
|
| 1300 |
|
| 1301 |
+
train_batches = blockwise_data_loader(
|
| 1302 |
+
input_rng,
|
| 1303 |
+
train_dataset,
|
| 1304 |
+
block_size=training_args.block_size,
|
| 1305 |
+
batch_size=train_batch_size,
|
| 1306 |
+
keep_in_memory=True,
|
| 1307 |
+
shuffle=True,
|
| 1308 |
+
split="train"
|
| 1309 |
+
)
|
| 1310 |
|
| 1311 |
# train
|
| 1312 |
for (batch_idx, _) in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)):
|
|
|
|
| 1336 |
if has_tensorboard and jax.process_index() == 0:
|
| 1337 |
write_metric(summary_writer, "train", train_metrics, cur_step, train_time=train_time)
|
| 1338 |
|
| 1339 |
+
# ======================== Evaluating (inside an epoch) ==============================
|
| 1340 |
|
| 1341 |
if training_args.do_eval and (training_args.eval_steps is not None and training_args.eval_steps > 0) and cur_step % training_args.eval_steps == 0:
|
| 1342 |
+
evaluate(input_rng, eval_dataset, ckpt_dir=f"ckpt_epoch_{epoch + 1}_step_{cur_step}")
|
| 1343 |
save_results(epoch, cur_step)
|
| 1344 |
|
| 1345 |
# ======================== Epoch End ==============================
|
|
|
|
| 1359 |
# ======================== Evaluating (after each epoch) ==============================
|
| 1360 |
|
| 1361 |
if training_args.do_eval and (training_args.eval_steps is None or training_args.eval_steps <= 0):
|
| 1362 |
+
evaluate(input_rng, eval_dataset, ckpt_dir=f"ckpt_epoch_{epoch + 1}_step_{cur_step}")
|
| 1363 |
save_results(epoch, cur_step)
|
| 1364 |
|
| 1365 |
+
# ======================== Evaluating | Predicting ==============================
|
| 1366 |
|
| 1367 |
# Create sampling rng
|
| 1368 |
if input_rng is None:
|
| 1369 |
rng, input_rng = jax.random.split(rng)
|
| 1370 |
|
| 1371 |
+
# run evaluation without training
|
| 1372 |
+
if training_args.do_eval and not training_args.do_train:
|
| 1373 |
+
evaluate(input_rng, eval_dataset)
|
| 1374 |
+
|
| 1375 |
+
# run prediction after (or without) training
|
| 1376 |
if training_args.do_predict:
|
| 1377 |
predict(input_rng, predict_dataset)
|
| 1378 |
|