ydshieh
commited on
Commit
·
10c0ce9
1
Parent(s):
cff52cf
clean up
Browse files- run_image_captioning_flax.py +50 -40
run_image_captioning_flax.py
CHANGED
@@ -123,11 +123,9 @@ class TrainingArguments:
|
|
123 |
label_smoothing_factor: float = field(
|
124 |
default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
|
125 |
)
|
126 |
-
adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
|
127 |
num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
|
128 |
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
|
129 |
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
|
130 |
-
save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
|
131 |
eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
|
132 |
dataloader_drop_last: bool = field(
|
133 |
default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
|
@@ -1155,24 +1153,24 @@ def main():
|
|
1155 |
commit_msg = f"Saving weights and logs of epoch {epoch + 1}- step {step}"
|
1156 |
repo.push_to_hub(commit_message=commit_msg, blocking=False)
|
1157 |
|
1158 |
-
def
|
1159 |
|
1160 |
-
if
|
1161 |
-
raise ValueError(f"`name` must be either \"valid\" or \"test\". Got {
|
1162 |
|
1163 |
-
logger.info(f"*** {'Predict' if
|
1164 |
|
1165 |
metrics = []
|
1166 |
preds = []
|
1167 |
labels = []
|
1168 |
|
1169 |
-
batches = get_batch_iter(rng, dataset, block_size=block_size, batch_size=eval_batch_size, keep_in_memory=False, shuffle=False, split=
|
1170 |
steps = len(dataset) // eval_batch_size + int(len(dataset) % eval_batch_size > 0)
|
1171 |
-
for _ in tqdm(range(steps), desc=f"{'Predicting' if
|
1172 |
# Model forward
|
1173 |
batch = next(batches)
|
1174 |
_labels = batch.get("labels", None)
|
1175 |
-
if
|
1176 |
raise ValueError("Validation dataset requires `labels`")
|
1177 |
|
1178 |
if _labels is not None:
|
@@ -1198,7 +1196,7 @@ def main():
|
|
1198 |
if labels:
|
1199 |
rouge_metrics, decoded_preds, decoded_labels = compute_metrics(preds, labels)
|
1200 |
metrics.update(rouge_metrics)
|
1201 |
-
rouge_desc = " ".join([f"{'Predict' if
|
1202 |
for pred, label in zip(decoded_preds, decoded_labels):
|
1203 |
pred = pred.replace("\n", " ")
|
1204 |
label = label.replace("\n", " ")
|
@@ -1215,8 +1213,8 @@ def main():
|
|
1215 |
|
1216 |
if metrics:
|
1217 |
# Print metrics and update progress bar
|
1218 |
-
desc = f"{'Predict' if
|
1219 |
-
if
|
1220 |
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | " + desc
|
1221 |
epochs.write(desc)
|
1222 |
epochs.desc = desc
|
@@ -1225,7 +1223,7 @@ def main():
|
|
1225 |
if jax.process_index() == 0:
|
1226 |
|
1227 |
ckpt_dir = ""
|
1228 |
-
if
|
1229 |
ckpt_dir = f'ckpt_epoch_{epoch + 1}_step_{cur_step}'
|
1230 |
if not os.path.isdir(os.path.join(training_args.output_dir, ckpt_dir)):
|
1231 |
os.makedirs(os.path.join(training_args.output_dir, ckpt_dir), exist_ok=True)
|
@@ -1233,9 +1231,9 @@ def main():
|
|
1233 |
if metrics:
|
1234 |
|
1235 |
# save final metrics in json
|
1236 |
-
metrics = {f"{
|
1237 |
-
|
1238 |
-
with open(
|
1239 |
json.dump(metrics, f, indent=4, sort_keys=True)
|
1240 |
|
1241 |
# Update report
|
@@ -1243,14 +1241,20 @@ def main():
|
|
1243 |
fp.write(desc + '\n')
|
1244 |
|
1245 |
# Save metrics
|
1246 |
-
if has_tensorboard and
|
1247 |
-
write_metric(summary_writer,
|
1248 |
|
1249 |
# Save generations
|
1250 |
if generations:
|
1251 |
-
with open(os.path.join(training_args.output_dir, ckpt_dir, f'generation_{
|
1252 |
json.dump(generations, fp, ensure_ascii=False, indent=4)
|
1253 |
|
|
|
|
|
|
|
|
|
|
|
|
|
1254 |
input_rng = None
|
1255 |
|
1256 |
if training_args.do_train:
|
@@ -1280,15 +1284,17 @@ def main():
|
|
1280 |
train_metrics.append(train_metric)
|
1281 |
train_time += time.time() - batch_start
|
1282 |
|
1283 |
-
|
|
|
|
|
|
|
|
|
1284 |
|
1285 |
-
|
|
|
1286 |
|
1287 |
-
_train_metric = unreplicate(train_metric)
|
1288 |
-
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | Loss: {_train_metric['loss']} | Learning Rate: {_train_metric['learning_rate']} | Time per step: {time_per_step})"
|
1289 |
-
epochs.desc = desc
|
1290 |
-
epochs.write(desc)
|
1291 |
logger.info(desc)
|
|
|
1292 |
with open(os.path.join(training_args.output_dir, 'report.txt'), 'a', encoding='UTF-8') as fp:
|
1293 |
fp.write(desc + '\n')
|
1294 |
|
@@ -1296,34 +1302,38 @@ def main():
|
|
1296 |
if has_tensorboard and jax.process_index() == 0:
|
1297 |
write_metric(summary_writer, "train", train_metrics, cur_step, train_time=train_time)
|
1298 |
|
1299 |
-
|
|
|
|
|
|
|
1300 |
|
1301 |
-
|
1302 |
-
run_eval_or_test(input_rng, eval_dataset, name="valid", is_inside_training=True)
|
1303 |
|
1304 |
-
|
|
|
1305 |
|
1306 |
-
|
1307 |
-
if training_args.do_predict and training_args.do_predict_during_training and training_args.do_predict_after_evaluation:
|
1308 |
-
run_eval_or_test(input_rng, predict_dataset, name='test', is_inside_training=True)
|
1309 |
|
1310 |
-
|
|
|
1311 |
|
1312 |
-
|
1313 |
-
|
|
|
1314 |
|
1315 |
-
|
1316 |
-
|
1317 |
-
run_eval_or_test(input_rng, predict_dataset, name='test', is_inside_training=True)
|
1318 |
save_results(epoch + 1, cur_step)
|
1319 |
|
|
|
|
|
1320 |
# Create sampling rng
|
1321 |
if input_rng is None:
|
1322 |
rng, input_rng = jax.random.split(rng)
|
1323 |
|
1324 |
# run prediction after each epoch (if not done during training)
|
1325 |
-
if training_args.do_predict
|
1326 |
-
|
1327 |
|
1328 |
|
1329 |
if __name__ == "__main__":
|
|
|
123 |
label_smoothing_factor: float = field(
|
124 |
default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
|
125 |
)
|
|
|
126 |
num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
|
127 |
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
|
128 |
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
|
|
|
129 |
eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
|
130 |
dataloader_drop_last: bool = field(
|
131 |
default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
|
|
|
1153 |
commit_msg = f"Saving weights and logs of epoch {epoch + 1}- step {step}"
|
1154 |
repo.push_to_hub(commit_message=commit_msg, blocking=False)
|
1155 |
|
1156 |
+
def evaluation_loop(rng, dataset, split):
|
1157 |
|
1158 |
+
if split not in ["valid", "test"]:
|
1159 |
+
raise ValueError(f"`name` must be either \"valid\" or \"test\". Got {split} instead.")
|
1160 |
|
1161 |
+
logger.info(f"*** {'Predict' if split == 'test' else 'Evaluate'} ***")
|
1162 |
|
1163 |
metrics = []
|
1164 |
preds = []
|
1165 |
labels = []
|
1166 |
|
1167 |
+
batches = get_batch_iter(rng, dataset, block_size=block_size, batch_size=eval_batch_size, keep_in_memory=False, shuffle=False, split=split)
|
1168 |
steps = len(dataset) // eval_batch_size + int(len(dataset) % eval_batch_size > 0)
|
1169 |
+
for _ in tqdm(range(steps), desc=f"{'Predicting' if split == 'test' else 'Evaluating'}...", position=2, leave=False):
|
1170 |
# Model forward
|
1171 |
batch = next(batches)
|
1172 |
_labels = batch.get("labels", None)
|
1173 |
+
if split == "valid" and _labels is None:
|
1174 |
raise ValueError("Validation dataset requires `labels`")
|
1175 |
|
1176 |
if _labels is not None:
|
|
|
1196 |
if labels:
|
1197 |
rouge_metrics, decoded_preds, decoded_labels = compute_metrics(preds, labels)
|
1198 |
metrics.update(rouge_metrics)
|
1199 |
+
rouge_desc = " ".join([f"{'Predict' if split == 'test' else 'Eval'} {key}: {value} |" for key, value in rouge_metrics.items()])
|
1200 |
for pred, label in zip(decoded_preds, decoded_labels):
|
1201 |
pred = pred.replace("\n", " ")
|
1202 |
label = label.replace("\n", " ")
|
|
|
1213 |
|
1214 |
if metrics:
|
1215 |
# Print metrics and update progress bar
|
1216 |
+
desc = f"{'Predict' if split == 'test' else 'Eval'} Loss: {metrics['loss']} | {rouge_desc})"
|
1217 |
+
if split == "valid":
|
1218 |
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | " + desc
|
1219 |
epochs.write(desc)
|
1220 |
epochs.desc = desc
|
|
|
1223 |
if jax.process_index() == 0:
|
1224 |
|
1225 |
ckpt_dir = ""
|
1226 |
+
if split == "valid":
|
1227 |
ckpt_dir = f'ckpt_epoch_{epoch + 1}_step_{cur_step}'
|
1228 |
if not os.path.isdir(os.path.join(training_args.output_dir, ckpt_dir)):
|
1229 |
os.makedirs(os.path.join(training_args.output_dir, ckpt_dir), exist_ok=True)
|
|
|
1231 |
if metrics:
|
1232 |
|
1233 |
# save final metrics in json
|
1234 |
+
metrics = {f"{split}_{metric_name}": round(value.item(), 6) for metric_name, value in metrics.items()}
|
1235 |
+
_path = os.path.join(training_args.output_dir, ckpt_dir, f"{split}_results.json")
|
1236 |
+
with open(_path, "w") as f:
|
1237 |
json.dump(metrics, f, indent=4, sort_keys=True)
|
1238 |
|
1239 |
# Update report
|
|
|
1241 |
fp.write(desc + '\n')
|
1242 |
|
1243 |
# Save metrics
|
1244 |
+
if has_tensorboard and split == "valid":
|
1245 |
+
write_metric(summary_writer, split, metrics, cur_step)
|
1246 |
|
1247 |
# Save generations
|
1248 |
if generations:
|
1249 |
+
with open(os.path.join(training_args.output_dir, ckpt_dir, f'generation_{split}.json'), 'w', encoding='UTF-8') as fp:
|
1250 |
json.dump(generations, fp, ensure_ascii=False, indent=4)
|
1251 |
|
1252 |
+
def evaluate(rng, dataset):
|
1253 |
+
evaluation_loop(rng, dataset, split='eval')
|
1254 |
+
|
1255 |
+
def predict(rng, dataset):
|
1256 |
+
evaluation_loop(rng, dataset, split='test')
|
1257 |
+
|
1258 |
input_rng = None
|
1259 |
|
1260 |
if training_args.do_train:
|
|
|
1284 |
train_metrics.append(train_metric)
|
1285 |
train_time += time.time() - batch_start
|
1286 |
|
1287 |
+
time_per_step = train_time / cur_step
|
1288 |
+
_train_metric = unreplicate(train_metric)
|
1289 |
+
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | Loss: {_train_metric['loss']} | Learning Rate: {_train_metric['learning_rate']} | Time per step: {time_per_step})"
|
1290 |
+
epochs.desc = desc
|
1291 |
+
epochs.write(desc)
|
1292 |
|
1293 |
+
# log and save info
|
1294 |
+
if training_args.logging_steps > 0 and cur_step % training_args.logging_steps == 0:
|
1295 |
|
|
|
|
|
|
|
|
|
1296 |
logger.info(desc)
|
1297 |
+
|
1298 |
with open(os.path.join(training_args.output_dir, 'report.txt'), 'a', encoding='UTF-8') as fp:
|
1299 |
fp.write(desc + '\n')
|
1300 |
|
|
|
1302 |
if has_tensorboard and jax.process_index() == 0:
|
1303 |
write_metric(summary_writer, "train", train_metrics, cur_step, train_time=train_time)
|
1304 |
|
1305 |
+
# ======================== Evaluating ==============================
|
1306 |
+
if training_args.eval_steps is not None and training_args.eval_steps > 0 and cur_step % training_args.eval_steps == 0:
|
1307 |
+
evaluate(input_rng, eval_dataset)
|
1308 |
+
save_results(epoch + 1, cur_step)
|
1309 |
|
1310 |
+
# ======================== Epoch End ==============================
|
|
|
1311 |
|
1312 |
+
# log and save info
|
1313 |
+
if training_args.logging_steps <= 0:
|
1314 |
|
1315 |
+
logger.info(desc)
|
|
|
|
|
1316 |
|
1317 |
+
with open(os.path.join(training_args.output_dir, 'report.txt'), 'a', encoding='UTF-8') as fp:
|
1318 |
+
fp.write(desc + '\n')
|
1319 |
|
1320 |
+
# Save metrics
|
1321 |
+
if has_tensorboard and jax.process_index() == 0:
|
1322 |
+
write_metric(summary_writer, "train", train_metrics, cur_step, train_time=train_time)
|
1323 |
|
1324 |
+
if training_args.eval_steps is None or training_args.eval_steps <= 0:
|
1325 |
+
evaluate(input_rng, eval_dataset)
|
|
|
1326 |
save_results(epoch + 1, cur_step)
|
1327 |
|
1328 |
+
# ======================== Prediction loop ==============================
|
1329 |
+
|
1330 |
# Create sampling rng
|
1331 |
if input_rng is None:
|
1332 |
rng, input_rng = jax.random.split(rng)
|
1333 |
|
1334 |
# run prediction after each epoch (if not done during training)
|
1335 |
+
if training_args.do_predict:
|
1336 |
+
predict(input_rng, predict_dataset)
|
1337 |
|
1338 |
|
1339 |
if __name__ == "__main__":
|