ydshieh
committed on
Commit
·
ddad56c
1
Parent(s):
ea4daa2
update 11
Browse files
run_image_captioning_flax_reduced.py
CHANGED
|
@@ -507,12 +507,6 @@ def main():
|
|
| 507 |
decoder_config.is_decoder = True
|
| 508 |
decoder_config.add_cross_attention = True
|
| 509 |
|
| 510 |
-
# GPT2 only has bos/eos token but not decoder_start/pad token
|
| 511 |
-
if decoder_config.decoder_start_token_id is None:
|
| 512 |
-
decoder_config.decoder_start_token_id = decoder_config.bos_token_id
|
| 513 |
-
if decoder_config.pad_token_id is None:
|
| 514 |
-
decoder_config.pad_token_id = decoder_config.eos_token_id
|
| 515 |
-
|
| 516 |
model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
|
| 517 |
encoder_pretrained_model_name_or_path=model_args.encoder_model_name_or_path,
|
| 518 |
decoder_pretrained_model_name_or_path=model_args.decoder_model_name_or_path,
|
|
@@ -523,10 +517,19 @@ def main():
|
|
| 523 |
encoder_dtype=getattr(jnp, model_args.dtype),
|
| 524 |
decoder_dtype=getattr(jnp, model_args.dtype),
|
| 525 |
)
|
| 526 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 527 |
model.config.eos_token_id = decoder_config.eos_token_id
|
| 528 |
-
model.config.decoder_start_token_id =
|
| 529 |
-
model.config.pad_token_id =
|
| 530 |
|
| 531 |
if model_args.feature_extractor_name:
|
| 532 |
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
|
@@ -556,7 +559,7 @@ def main():
|
|
| 556 |
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
|
| 557 |
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
|
| 558 |
)
|
| 559 |
-
tokenizer.pad_token = tokenizer.convert_ids_to_tokens(model.config.
|
| 560 |
|
| 561 |
# Preprocessing the datasets.
|
| 562 |
# We need to tokenize inputs and targets.
|
|
@@ -631,7 +634,7 @@ def main():
|
|
| 631 |
|
| 632 |
model_inputs["labels"] = labels["input_ids"]
|
| 633 |
decoder_input_ids = shift_tokens_right_fn(
|
| 634 |
-
labels["input_ids"], model.config.
|
| 635 |
)
|
| 636 |
model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
|
| 637 |
|
|
|
|
| 507 |
decoder_config.is_decoder = True
|
| 508 |
decoder_config.add_cross_attention = True
|
| 509 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 510 |
model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
|
| 511 |
encoder_pretrained_model_name_or_path=model_args.encoder_model_name_or_path,
|
| 512 |
decoder_pretrained_model_name_or_path=model_args.decoder_model_name_or_path,
|
|
|
|
| 517 |
encoder_dtype=getattr(jnp, model_args.dtype),
|
| 518 |
decoder_dtype=getattr(jnp, model_args.dtype),
|
| 519 |
)
|
| 520 |
+
|
| 521 |
+
# GPT2 only has bos/eos tokens but not decoder_start/pad tokens
|
| 522 |
+
decoder_start_token_id = decoder_config.decoder_start_token_id
|
| 523 |
+
pad_token_id = decoder_config.pad_token_id
|
| 524 |
+
if decoder_start_token_id is None:
|
| 525 |
+
decoder_config.pad_token_id = decoder_config.bos_token_id
|
| 526 |
+
if pad_token_id is None:
|
| 527 |
+
pad_token_id = decoder_config.pad_token_id
|
| 528 |
+
|
| 529 |
+
# This is necessary to make Flax's generate() work
|
| 530 |
model.config.eos_token_id = decoder_config.eos_token_id
|
| 531 |
+
model.config.decoder_start_token_id = decoder_start_token_id
|
| 532 |
+
model.config.pad_token_id = pad_token_id
|
| 533 |
|
| 534 |
if model_args.feature_extractor_name:
|
| 535 |
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
|
|
|
| 559 |
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
|
| 560 |
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
|
| 561 |
)
|
| 562 |
+
tokenizer.pad_token = tokenizer.convert_ids_to_tokens(model.config.pad_token_id)
|
| 563 |
|
| 564 |
# Preprocessing the datasets.
|
| 565 |
# We need to tokenize inputs and targets.
|
|
|
|
| 634 |
|
| 635 |
model_inputs["labels"] = labels["input_ids"]
|
| 636 |
decoder_input_ids = shift_tokens_right_fn(
|
| 637 |
+
labels["input_ids"], model.config.pad_token_id, model.config.decoder_start_token_id
|
| 638 |
)
|
| 639 |
model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
|
| 640 |
|