ydshieh committed
Commit 2c5a28b · 1 Parent(s): 16517d8
update to be as a base

run_image_captioning_flax_reduced.py (+128 -64)
CHANGED
@@ -32,8 +32,8 @@ import datasets
 import nltk  # Here to have a nice missing dependency error message early on
 import numpy as np
 from datasets import Dataset, load_dataset, load_metric
-from tqdm import tqdm
 from PIL import Image
+from tqdm import tqdm

 import jax
 import jax.numpy as jnp
@@ -47,14 +47,14 @@ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
 from huggingface_hub import Repository
 from transformers import (
     CONFIG_MAPPING,
-    FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
     FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
+    FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
     AutoConfig,
     AutoFeatureExtractor,
     AutoTokenizer,
+    FlaxAutoModelForVision2Seq,
     HfArgumentParser,
     is_tensorboard_available,
-    FlaxAutoModelForVision2Seq,
 )
 from transformers.file_utils import get_full_repo_name, is_offline_mode

@@ -113,8 +113,7 @@ class TrainingArguments:
     per_device_eval_batch_size: int = field(
         default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
     )
-    _block_size_doc =
-    """
+    _block_size_doc = """
     The default value `0` will preprocess (tokenization + feature extraction) the whole dataset before training and
     cache the results. This uses more disk space, but avoids (repeated) processing time during training. This is a
     good option if your disk space is large enough to store the whole processed dataset.
@@ -124,10 +123,7 @@
     `batch_size` are yielded before processing the next block. This could avoid the heavy disk usage when the
     dataset is large.
     """
-    block_size: int = field(
-        default=0,
-        metadata={"help": _block_size_doc}
-    )
+    block_size: int = field(default=0, metadata={"help": _block_size_doc})
     learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
     weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
     adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
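Note: the `block_size` field above is consumed by `HfArgumentParser`, which turns dataclass fields into CLI flags. A minimal, hedged sketch of that mechanism (standalone demo, not code from this commit):

from dataclasses import dataclass, field

from transformers import HfArgumentParser

@dataclass
class DemoArguments:
    # Mirrors the shape of `block_size` above; the help text becomes `--help` output.
    block_size: int = field(default=0, metadata={"help": "Process the dataset in blocks of this size."})

# `--block_size 1024` on the command line maps onto the dataclass field.
(demo_args,) = HfArgumentParser(DemoArguments).parse_args_into_dataclasses(["--block_size", "1024"])
print(demo_args.block_size)  # 1024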
@@ -197,16 +193,21 @@ class ModelArguments:
         },
     )
     model_type: Optional[str] = field(
-        default=
-        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}
+        default="vision-encoder-decoder",
+        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
     )
     encoder_model_type: Optional[str] = field(
         default=None,
-        metadata={
+        metadata={
+            "help": "If training from scratch, pass a vision encoder model type from the library. For example, 'vit'"
+        },
     )
     decoder_model_type: Optional[str] = field(
         default=None,
-        metadata={
+        metadata={
+            "help": "If training from scratch, pass a decoder model type from the list: "
+            + ", ".join(DECODER_MODEL_TYPES)
+        },
     )
     config_name: Optional[str] = field(
         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
@@ -218,10 +219,12 @@ class ModelArguments:
         default=None, metadata={"help": "Pretrained decoder config name or path if not the same as decoder_model_name"}
     )
     feature_extractor_name: Optional[str] = field(
-        default=None,
+        default=None,
+        metadata={"help": "Pretrained encoder feature extractor_name or path if not the same as encoder_model_name"},
     )
     tokenizer_name: Optional[str] = field(
-        default=None,
+        default=None,
+        metadata={"help": "Pretrained decoder tokenizer name or path if not the same as decoder_model_name"},
     )
     cache_dir: Optional[str] = field(
         default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
@@ -505,7 +508,7 @@ def main():
     # Use specified `model_type` (default to `vision-encoder-decoder`)
     else:

-        if
+        if model_args.model_type not in MODEL_TYPES:
             raise ValueError(
                 f"Unrecognized model identifier: {model_args.model_type}. Should contain one of {', '.join(MODEL_TYPES)}."
             )
@@ -516,29 +519,41 @@ def main():

     # Use explicit specified encoder config
     if model_args.encoder_config_name:
-        encoder_config = AutoConfig.from_pretrained(model_args.encoder_config_name, cache_dir=encoder_cache_dir)
+        encoder_config = AutoConfig.from_pretrained(
+            model_args.encoder_config_name, cache_dir=encoder_cache_dir
+        )
     # Use pretrained encoder model's config
     elif model_args.encoder_model_name_or_path:
-        encoder_config = AutoConfig.from_pretrained(model_args.encoder_model_name_or_path, cache_dir=encoder_cache_dir)
+        encoder_config = AutoConfig.from_pretrained(
+            model_args.encoder_model_name_or_path, cache_dir=encoder_cache_dir
+        )
     # Use specified encoder model type
     elif model_args.encoder_model_type:
         encoder_config = AutoConfig.for_model(model_args.encoder_model_type)
         logger.warning("You are instantiating a new config instance from scratch for the encoder.")
     else:
-        raise ValueError(
+        raise ValueError(
+            "Encoder Config: if pretrained config or model location is not provided, `encoder_model_type` is required."
+        )

     # Use explicit specified decoder config
     if model_args.decoder_config_name:
-        decoder_config = AutoConfig.from_pretrained(model_args.decoder_config_name, cache_dir=decoder_cache_dir)
+        decoder_config = AutoConfig.from_pretrained(
+            model_args.decoder_config_name, cache_dir=decoder_cache_dir
+        )
     # Use pretrained decoder model's config
     elif model_args.decoder_model_name_or_path:
-        decoder_config = AutoConfig.from_pretrained(model_args.decoder_model_name_or_path, cache_dir=decoder_cache_dir)
+        decoder_config = AutoConfig.from_pretrained(
+            model_args.decoder_model_name_or_path, cache_dir=decoder_cache_dir
+        )
     # Use specified decoder model type
     elif model_args.decoder_model_type:
         decoder_config = AutoConfig.for_model(model_args.decoder_model_type)
         logger.warning("You are instantiating a new config instance from scratch for the decoder.")
     else:
-        raise ValueError(
+        raise ValueError(
+            "Decoder Config: if pretrained config or model location is not provided, `decoder_model_type` is required."
+        )

     logger.info("Setting `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
     decoder_config.is_decoder = True
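Note: the branching above only resolves `encoder_config` and `decoder_config`; the script then composes them into one `config`. A minimal sketch of that composition, assuming `VisionEncoderDecoderConfig` is the `vision-encoder-decoder` config class (the model types here are illustrative):

from transformers import AutoConfig, VisionEncoderDecoderConfig

encoder_config = AutoConfig.for_model("vit")   # illustrative encoder type
decoder_config = AutoConfig.for_model("gpt2")  # illustrative decoder type
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True

config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
print(config.model_type)  # vision-encoder-decoder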
@@ -586,7 +601,9 @@ def main():
         )
     else:
         # model_class = FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING[config.__class__]
-        model = FlaxAutoModelForVision2Seq.from_config(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
+        model = FlaxAutoModelForVision2Seq.from_config(
+            config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        )
         model_class = model.__class__

     # encoder_class = FlaxAutoModel
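Note: `getattr(jnp, model_args.dtype)` turns the `--dtype` string into the matching JAX dtype object; a quick illustration:

import jax.numpy as jnp

# "float32", "float16" and "bfloat16" are the usual choices for --dtype.
dtype = getattr(jnp, "bfloat16")
assert dtype is jnp.bfloat16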
@@ -604,10 +621,12 @@ def main():
             model_args.encoder_model_name_or_path,
             config=config.encoder,
             seed=training_args.seed,
-            dtype=getattr(jnp, model_args.dtype)
+            dtype=getattr(jnp, model_args.dtype),
         )
     else:
-        encoder = encoder_class(config=config.encoder, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
+        encoder = encoder_class(
+            config=config.encoder, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        )
         logger.warning("You are instantiating a new model instance from scratch for the encoder.")

     if model_args.decoder_model_name_or_path:
@@ -615,10 +634,12 @@ def main():
             model_args.decoder_model_name_or_path,
             config=config.decoder,
             seed=training_args.seed,
-            dtype=getattr(jnp, model_args.dtype)
+            dtype=getattr(jnp, model_args.dtype),
         )
     else:
-        decoder = decoder_class(config=config.decoder, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
+        decoder = decoder_class(
+            config=config.decoder, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        )
         logger.warning("You are instantiating a new model instance from scratch for the decoder.")

     model = model_class.from_encoder_decoder_pretrained(
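Note: `model_class.from_encoder_decoder_pretrained(...)` stitches two separately pretrained checkpoints into one vision-encoder-decoder model. A hedged sketch with illustrative checkpoint names (not necessarily this commit's defaults):

from transformers import FlaxVisionEncoderDecoderModel

# Illustrative checkpoints; the cross-attention weights are freshly initialized.
model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k", "gpt2"
)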
@@ -646,7 +667,8 @@ def main():
     feature_extractor = None
     if model_args.feature_extractor_name:
         feature_extractor = AutoFeatureExtractor.from_pretrained(
-            model_args.feature_extractor_name, cache_dir=model_args.cache_dir
+            model_args.feature_extractor_name,
+            cache_dir=model_args.cache_dir,
         )
     elif model_args.model_name_or_path:
         try:
@@ -684,7 +706,9 @@ def main():
     if not tokenizer:
         if model_args.decoder_model_name_or_path:
             tokenizer = AutoTokenizer.from_pretrained(
-                model_args.decoder_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+                model_args.decoder_model_name_or_path,
+                cache_dir=model_args.cache_dir,
+                use_fast=model_args.use_fast_tokenizer,
             )
         else:
             raise ValueError(
@@ -739,9 +763,9 @@ def main():
        for image_file in examples[image_column]:
            try:
                image = Image.open(image_file)
-
+                feature_extractor(images=image, return_tensors="np")
                bools.append(True)
-            except:
+            except Exception:
                bools.append(False)

        return bools
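Note: running `feature_extractor(images=image, return_tensors="np")` inside the `try` makes the filter also reject files that PIL opens but the feature extractor cannot process. The switch from a bare `except:` to `except Exception:` matters because a bare clause additionally swallows `KeyboardInterrupt` and `SystemExit`, which do not derive from `Exception`:

print(issubclass(ValueError, Exception))         # True: still caught by `except Exception`
print(issubclass(KeyboardInterrupt, Exception))  # False: Ctrl-C now propagates out of the filter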
@@ -752,7 +776,7 @@ def main():

        captions = []
        for caption in examples[caption_column]:
-
+            captions.append(caption.lower() + " " + tokenizer.eos_token)

        targets = captions

@@ -795,7 +819,7 @@ def main():
                img = Image.open(image_file)
                images.append(img)
                to_keep.append(True)
-            except:
+            except Exception:
                to_keep.append(False)

        for k, v in examples.items():
@@ -831,9 +855,11 @@ def main():
                ),
                dtype="float32",
            ),
-            "labels": datasets.Sequence(feature=datasets.Value(dtype=
-            "decoder_input_ids": datasets.Sequence(feature=datasets.Value(dtype=
-            "decoder_attention_mask": datasets.Sequence(
+            "labels": datasets.Sequence(feature=datasets.Value(dtype="int32", id=None), length=-1, id=None),
+            "decoder_input_ids": datasets.Sequence(feature=datasets.Value(dtype="int32", id=None), length=-1, id=None),
+            "decoder_attention_mask": datasets.Sequence(
+                feature=datasets.Value(dtype="int32", id=None), length=-1, id=None
+            ),
        }
    )

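Note: the `datasets.Features` entries added above declare variable-length int32 sequences; a minimal standalone sketch of the same declaration:

import datasets

features = datasets.Features(
    {
        # length=-1 means variable-length; the dtype matches the tokenized label ids.
        "labels": datasets.Sequence(feature=datasets.Value(dtype="int32", id=None), length=-1, id=None),
        "decoder_input_ids": datasets.Sequence(feature=datasets.Value(dtype="int32", id=None), length=-1, id=None),
    }
)
print(features["labels"].feature.dtype)  # int32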
@@ -909,7 +935,9 @@ def main():
        # (if feature extraction is performed at the beginning, the filtering is done during preprocessing below
        # instead here.)
        if not run_feat_ext_at_beginning:
-            predict_dataset = predict_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
+            predict_dataset = predict_dataset.filter(
+                filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers
+            )
        predict_dataset = predict_dataset.map(
            function=function_kwarg,
            batched=True,
@@ -930,7 +958,9 @@ def main():
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()

    if training_args.block_size % train_batch_size > 0:
-        raise ValueError(f"`training_args.block_size` needs to be a multiple of the global batch size. Got {training_args.block_size} and {train_batch_size} instead.")
+        raise ValueError(
+            f"`training_args.block_size` needs to be a multiple of the global batch size. Got {training_args.block_size} and {train_batch_size} instead."
+        )

    if training_args.do_train:
        steps_per_epoch = len(train_dataset) // train_batch_size
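Note: the check above enforces that `block_size` is a multiple of the global batch size (per-device batch size times `jax.device_count()`). Worked numbers, assuming 8 devices:

per_device_train_batch_size = 16
device_count = 8  # what jax.device_count() would report on e.g. a TPU v3-8
train_batch_size = per_device_train_batch_size * device_count  # 128

print(1024 % train_batch_size == 0)  # True: block_size=1024 passes the check
print(1000 % train_batch_size == 0)  # False: block_size=1000 would raise ValueError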
@@ -951,13 +981,13 @@ def main():
    test_steps = num_test_examples // eval_batch_size

    def blockwise_data_loader(
-        rng,
-        ds,
-        block_size,
-        batch_size,
-        shuffle=False,
-        keep_in_memory=False,
-        split="",
+        rng: jax.random.PRNGKey,
+        ds: Dataset,
+        block_size: int,
+        batch_size: int,
+        shuffle: bool = False,
+        keep_in_memory: bool = False,
+        split: str = "",
    ):
        """
        Wrap the simple `data_loader` in a block-wise way if `block_size` > 0, else it's the same as `data_loader`.
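Note: a hedged sketch of the block-wise idea behind `blockwise_data_loader`: process `block_size` examples at a time and yield batches from the current block, so only one block of processed data is held at once (the helper below is illustrative, not the script's implementation):

def blockwise_batches(ds, block_size, batch_size):
    # Iterate the dataset one block at a time.
    for start in range(0, len(ds), block_size):
        block = ds.select(range(start, min(start + block_size, len(ds))))
        # ... expensive .map() preprocessing would run on `block` here ...
        for i in range(0, len(block), batch_size):
            yield block[i : i + batch_size]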
@@ -1165,7 +1195,7 @@ def main():

    def generate_step(params, batch):
        model.params = params
-        output_ids = model.generate(batch[
+        output_ids = model.generate(batch["pixel_values"], **gen_kwargs)
        return output_ids.sequences

    # Create parallel version of the train and eval step
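Note: `generate_step` is later wrapped with `jax.pmap` and fed sharded batches. A runnable toy of that data layout, with a squaring step standing in for generation:

import jax
import jax.numpy as jnp
from flax.training.common_utils import shard

def toy_step(batch):
    return batch["pixel_values"] ** 2  # stand-in for model.generate(...)

p_toy_step = jax.pmap(toy_step, axis_name="batch")

batch = {"pixel_values": jnp.ones((jax.device_count() * 2, 3))}
sharded = shard(batch)  # leading axis -> (num_devices, per_device_batch, ...)
print(p_toy_step(sharded).shape)  # (jax.device_count(), 2, 3)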
@@ -1212,7 +1242,13 @@ def main():
    if training_args.push_to_hub:
        repo.push_to_hub(commit_message=commit_msg, blocking=False)

-    def evaluation_loop(rng: jax.random.PRNGKey, dataset: Dataset, metric_key_prefix: str = "eval", ckpt_dir: str = "", is_prediction=False):
+    def evaluation_loop(
+        rng: jax.random.PRNGKey,
+        dataset: Dataset,
+        metric_key_prefix: str = "eval",
+        ckpt_dir: str = "",
+        is_prediction=False,
+    ):

        logger.info(f"*** {'Predict' if is_prediction else 'Evaluate'} ***")

@@ -1230,7 +1266,9 @@ def main():
            split="prediction" if is_prediction else "validation",
        )
        steps = len(dataset) // eval_batch_size
-        for _ in tqdm(range(steps), desc=f"{'Predicting' if is_prediction else 'Evaluating'}...", position=2, leave=False):
+        for _ in tqdm(
+            range(steps), desc=f"{'Predicting' if is_prediction else 'Evaluating'}...", position=2, leave=False
+        ):
            # Model forward
            batch = next(batches)
            _labels = batch.get("labels", None)
@@ -1260,7 +1298,12 @@ def main():
        if labels:
            rouge_metrics, decoded_preds, decoded_labels = compute_metrics(preds, labels)
            metrics.update(rouge_metrics)
-            rouge_desc = " ".join([f"{'Predict' if is_prediction else 'Eval'} {key}: {value} |" for key, value in rouge_metrics.items()])
+            rouge_desc = " ".join(
+                [
+                    f"{'Predict' if is_prediction else 'Eval'} {key}: {value} |"
+                    for key, value in rouge_metrics.items()
+                ]
+            )
            for pred, label in zip(decoded_preds, decoded_labels):
                pred = pred.replace("\n", " ")
                label = label.replace("\n", " ")
@@ -1293,28 +1336,37 @@ def main():

        # Save metrics (only for the evaluation/prediction being done along with training)
        if has_tensorboard and training_args.do_train:
-            write_metric(summary_writer, metrics, train_time=None, step=cur_step, metric_key_prefix=metric_key_prefix)
+            write_metric(
+                summary_writer, metrics, train_time=None, step=cur_step, metric_key_prefix=metric_key_prefix
+            )

        # save final metrics in json
-        metrics = {f"{metric_key_prefix}_{metric_name}": round(value.item(), 6) for metric_name, value in metrics.items()}
+        metrics = {
+            f"{metric_key_prefix}_{metric_name}": round(value.item(), 6)
+            for metric_name, value in metrics.items()
+        }
        _path = os.path.join(training_args.output_dir, ckpt_dir, f"{metric_key_prefix}_results.json")
        with open(_path, "w") as f:
            json.dump(metrics, f, indent=4, sort_keys=True)

        # Update report
-        with open(os.path.join(training_args.output_dir,
-        fp.write(desc +
+        with open(os.path.join(training_args.output_dir, "log"), "a", encoding="UTF-8") as fp:
+            fp.write(desc + "\n")

        # Save generations
        if generations:
-            with open(
+            with open(
+                os.path.join(training_args.output_dir, ckpt_dir, f"{metric_key_prefix}_generation.json"),
+                "w",
+                encoding="UTF-8",
+            ) as fp:
                json.dump(generations, fp, ensure_ascii=False, indent=4)

    def evaluate(rng: jax.random.PRNGKey, dataset: Dataset, ckpt_dir: str = ""):
-        evaluation_loop(rng, dataset, metric_key_prefix=
+        evaluation_loop(rng, dataset, metric_key_prefix="eval", ckpt_dir=ckpt_dir)

    def predict(rng: jax.random.PRNGKey, dataset: Dataset):
-        evaluation_loop(rng, dataset, metric_key_prefix=
+        evaluation_loop(rng, dataset, metric_key_prefix="test", is_prediction=True)

    input_rng = None

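Note: the rewritten metrics dict calls `.item()` before `round(...)` because JAX device scalars are not JSON-serializable; converting to plain Python floats first keeps `json.dump` happy:

import json

import jax.numpy as jnp

loss = jnp.asarray(0.123456789)                 # device scalar, as produced by the eval step
metrics = {"eval_loss": round(loss.item(), 6)}  # .item() -> plain Python float
print(json.dumps(metrics))                      # {"eval_loss": 0.123457}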
@@ -1340,7 +1392,7 @@ def main():
            batch_size=train_batch_size,
            keep_in_memory=True,
            shuffle=True,
-            split="train"
+            split="train",
        )

        # train
@@ -1364,16 +1416,26 @@ def main():

                logger.info(desc)

-                with open(os.path.join(training_args.output_dir,
-                fp.write(desc +
+                with open(os.path.join(training_args.output_dir, "log"), "a", encoding="UTF-8") as fp:
+                    fp.write(desc + "\n")

                # Save metrics
                if has_tensorboard and jax.process_index() == 0:
-                    write_metric(summary_writer, train_metrics, train_time=train_time, step=cur_step, metric_key_prefix="train")
+                    write_metric(
+                        summary_writer,
+                        train_metrics,
+                        train_time=train_time,
+                        step=cur_step,
+                        metric_key_prefix="train",
+                    )

            # ======================== Evaluating (inside an epoch) ==============================

-            if
+            if (
+                training_args.do_eval
+                and (training_args.eval_steps is not None and training_args.eval_steps > 0)
+                and cur_step % training_args.eval_steps == 0
+            ):
                ckpt_dir = f"ckpt_epoch_{epoch + 1}_step_{cur_step}"
                commit_msg = f"Saving weights and logs of epoch {epoch + 1} - step {cur_step}"
                evaluate(input_rng, eval_dataset, ckpt_dir)
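Note: the rewritten condition gates mid-epoch evaluation on a positive `eval_steps` and on the step counter; a small truth-table check:

eval_steps = 500
for cur_step in (499, 500, 1000):
    trigger = (eval_steps is not None and eval_steps > 0) and cur_step % eval_steps == 0
    print(cur_step, trigger)  # 499 False, 500 True, 1000 True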
@@ -1386,12 +1448,14 @@ def main():

            logger.info(desc)

-            with open(os.path.join(training_args.output_dir,
-            fp.write(desc +
+            with open(os.path.join(training_args.output_dir, "log"), "a", encoding="UTF-8") as fp:
+                fp.write(desc + "\n")

            # Save metrics
            if has_tensorboard and jax.process_index() == 0:
-                write_metric(summary_writer, train_metrics, train_time=train_time, step=cur_step, metric_key_prefix="train")
+                write_metric(
+                    summary_writer, train_metrics, train_time=train_time, step=cur_step, metric_key_prefix="train"
+                )

            # ======================== Evaluating (after each epoch) ==============================
