ydshieh
committed on
Commit
·
a5be38e
1
Parent(s):
fd1b4a2
improve code
Browse files- run_image_captioning_flax.py +13 -9
run_image_captioning_flax.py
CHANGED
@@ -48,6 +48,7 @@ from huggingface_hub import Repository
|
|
48 |
from transformers import (
|
49 |
CONFIG_MAPPING,
|
50 |
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
|
|
|
51 |
AutoConfig,
|
52 |
AutoFeatureExtractor,
|
53 |
AutoTokenizer,
|
@@ -73,6 +74,7 @@ except (LookupError, OSError):
|
|
73 |
|
74 |
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING.keys())
|
75 |
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
|
|
76 |
|
77 |
|
78 |
# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
|
@@ -196,15 +198,15 @@ class ModelArguments:
|
|
196 |
)
|
197 |
model_type: Optional[str] = field(
|
198 |
default='vision-encoder-decoder',
|
199 |
-
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}
|
200 |
)
|
201 |
encoder_model_type: Optional[str] = field(
|
202 |
default=None,
|
203 |
-
metadata={"help": "If training from scratch, pass a encoder model type from the
|
204 |
)
|
205 |
decoder_model_type: Optional[str] = field(
|
206 |
default=None,
|
207 |
-
metadata={"help": "If training from scratch, pass a decoder model type from the list: " + ", ".join(
|
208 |
)
|
209 |
config_name: Optional[str] = field(
|
210 |
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
@@ -570,7 +572,7 @@ def main():
|
|
570 |
config.decoder.eos_token_id = eos_token_id
|
571 |
config.decoder.pad_token_id = pad_token_id
|
572 |
|
573 |
-
# Set `encoder-decoder` (top-level) specific config
|
574 |
config.decoder_start_token_id = decoder_start_token_id
|
575 |
config.bos_token_id = bos_token_id
|
576 |
config.eos_token_id = eos_token_id
|
@@ -630,7 +632,7 @@ def main():
|
|
630 |
decoder_dtype=getattr(jnp, model_args.dtype),
|
631 |
)
|
632 |
|
633 |
-
# Set `encoder-decoder` (top-level) specific config
|
634 |
model.config.decoder_start_token_id = decoder_start_token_id
|
635 |
model.config.bos_token_id = bos_token_id
|
636 |
model.config.eos_token_id = eos_token_id
|
@@ -729,6 +731,7 @@ def main():
|
|
729 |
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right", shift_tokens_right)
|
730 |
|
731 |
def filter_fn(examples):
|
|
|
732 |
|
733 |
bools = []
|
734 |
for image_file in examples[image_column]:
|
@@ -1163,7 +1166,8 @@ def main():
|
|
1163 |
if not os.path.isdir(os.path.join(training_args.output_dir)):
|
1164 |
os.makedirs(os.path.join(training_args.output_dir), exist_ok=True)
|
1165 |
|
1166 |
-
def save_ckpt(ckpt_dir: str, commit_msg: str =""):
|
|
|
1167 |
|
1168 |
# save checkpoint after each epoch and push checkpoint to the hub
|
1169 |
if jax.process_index() == 0:
|
@@ -1259,7 +1263,7 @@ def main():
|
|
1259 |
json.dump(metrics, f, indent=4, sort_keys=True)
|
1260 |
|
1261 |
# Update report
|
1262 |
-
with open(os.path.join(training_args.output_dir, '
|
1263 |
fp.write(desc + '\n')
|
1264 |
|
1265 |
# Save metrics (only for the evaluation/prediction being done along with training)
|
@@ -1325,7 +1329,7 @@ def main():
|
|
1325 |
|
1326 |
logger.info(desc)
|
1327 |
|
1328 |
-
with open(os.path.join(training_args.output_dir, '
|
1329 |
fp.write(desc + '\n')
|
1330 |
|
1331 |
# Save metrics
|
@@ -1347,7 +1351,7 @@ def main():
|
|
1347 |
|
1348 |
logger.info(desc)
|
1349 |
|
1350 |
-
with open(os.path.join(training_args.output_dir, '
|
1351 |
fp.write(desc + '\n')
|
1352 |
|
1353 |
# Save metrics
|
|
|
48 |
from transformers import (
|
49 |
CONFIG_MAPPING,
|
50 |
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
|
51 |
+
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
|
52 |
AutoConfig,
|
53 |
AutoFeatureExtractor,
|
54 |
AutoTokenizer,
|
|
|
74 |
|
75 |
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING.keys())
|
76 |
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
77 |
+
DECODER_MODEL_TYPES = tuple(conf.model_type for conf in list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys()))
|
78 |
|
79 |
|
80 |
# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
|
|
|
198 |
)
|
199 |
model_type: Optional[str] = field(
|
200 |
default='vision-encoder-decoder',
|
201 |
+
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}
|
202 |
)
|
203 |
encoder_model_type: Optional[str] = field(
|
204 |
default=None,
|
205 |
+
metadata={"help": "If training from scratch, pass a encoder model type from the library"}
|
206 |
)
|
207 |
decoder_model_type: Optional[str] = field(
|
208 |
default=None,
|
209 |
+
metadata={"help": "If training from scratch, pass a decoder model type from the list: " + ", ".join(DECODER_MODEL_TYPES)}
|
210 |
)
|
211 |
config_name: Optional[str] = field(
|
212 |
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
|
|
572 |
config.decoder.eos_token_id = eos_token_id
|
573 |
config.decoder.pad_token_id = pad_token_id
|
574 |
|
575 |
+
# Set `encoder-decoder` (top-level) specific config (not always necessary, but can avoid generate() error sometimes)
|
576 |
config.decoder_start_token_id = decoder_start_token_id
|
577 |
config.bos_token_id = bos_token_id
|
578 |
config.eos_token_id = eos_token_id
|
|
|
632 |
decoder_dtype=getattr(jnp, model_args.dtype),
|
633 |
)
|
634 |
|
635 |
+
# Set `encoder-decoder` (top-level) specific config (not always necessary, but can avoid generate() error sometimes)
|
636 |
model.config.decoder_start_token_id = decoder_start_token_id
|
637 |
model.config.bos_token_id = bos_token_id
|
638 |
model.config.eos_token_id = eos_token_id
|
|
|
731 |
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right", shift_tokens_right)
|
732 |
|
733 |
def filter_fn(examples):
|
734 |
+
"""remove problematic images"""
|
735 |
|
736 |
bools = []
|
737 |
for image_file in examples[image_column]:
|
|
|
1166 |
if not os.path.isdir(os.path.join(training_args.output_dir)):
|
1167 |
os.makedirs(os.path.join(training_args.output_dir), exist_ok=True)
|
1168 |
|
1169 |
+
def save_ckpt(ckpt_dir: str, commit_msg: str = ""):
|
1170 |
+
"""save checkpoints and push to Hugging Face Hub if specified"""
|
1171 |
|
1172 |
# save checkpoint after each epoch and push checkpoint to the hub
|
1173 |
if jax.process_index() == 0:
|
|
|
1263 |
json.dump(metrics, f, indent=4, sort_keys=True)
|
1264 |
|
1265 |
# Update report
|
1266 |
+
with open(os.path.join(training_args.output_dir, 'log'), 'a', encoding='UTF-8') as fp:
|
1267 |
fp.write(desc + '\n')
|
1268 |
|
1269 |
# Save metrics (only for the evaluation/prediction being done along with training)
|
|
|
1329 |
|
1330 |
logger.info(desc)
|
1331 |
|
1332 |
+
with open(os.path.join(training_args.output_dir, 'log'), 'a', encoding='UTF-8') as fp:
|
1333 |
fp.write(desc + '\n')
|
1334 |
|
1335 |
# Save metrics
|
|
|
1351 |
|
1352 |
logger.info(desc)
|
1353 |
|
1354 |
+
with open(os.path.join(training_args.output_dir, 'log'), 'a', encoding='UTF-8') as fp:
|
1355 |
fp.write(desc + '\n')
|
1356 |
|
1357 |
# Save metrics
|