Fix some bugs
config.json
DELETED

@@ -1,56 +0,0 @@
-{
-  "architectures": [
-    "T5ForConditionalGeneration"
-  ],
-  "d_ff": 3072,
-  "d_kv": 64,
-  "d_model": 768,
-  "decoder_start_token_id": 0,
-  "dropout_rate": 0.1,
-  "eos_token_id": 1,
-  "feed_forward_proj": "relu",
-  "gradient_checkpointing": false,
-  "initializer_factor": 1.0,
-  "is_encoder_decoder": true,
-  "layer_norm_epsilon": 1e-06,
-  "model_type": "t5",
-  "n_positions": 512,
-  "num_decoder_layers": 12,
-  "num_heads": 12,
-  "num_layers": 12,
-  "output_past": true,
-  "pad_token_id": 0,
-  "relative_attention_num_buckets": 32,
-  "task_specific_params": {
-    "summarization": {
-      "early_stopping": true,
-      "length_penalty": 2.0,
-      "max_length": 200,
-      "min_length": 30,
-      "no_repeat_ngram_size": 3,
-      "num_beams": 4,
-      "prefix": "summarize: "
-    },
-    "translation_en_to_de": {
-      "early_stopping": true,
-      "max_length": 300,
-      "num_beams": 4,
-      "prefix": "translate English to German: "
-    },
-    "translation_en_to_fr": {
-      "early_stopping": true,
-      "max_length": 300,
-      "num_beams": 4,
-      "prefix": "translate English to French: "
-    },
-    "translation_en_to_ro": {
-      "early_stopping": true,
-      "max_length": 300,
-      "num_beams": 4,
-      "prefix": "translate English to Romanian: "
-    }
-  },
-  "transformers_version": "4.9.0.dev0",
-  "use_cache": true,
-  "vocab_size": 32128
-}
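Note: the deleted config.json matched the stock t5-base configuration, and the commit presumably relies on a fresh copy being written back when the fine-tuned model is saved and pushed. A minimal sketch of that regeneration, assuming the transformers library (the output path is illustrative):

from transformers import T5Config

# t5-base is the architecture the deleted file described; saving the
# config writes a fresh config.json into the target directory.
config = T5Config.from_pretrained("t5-base")
config.save_pretrained("./t5-recipe-generation")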
events.out.tfevents.1625682591.t1v-n-a0c138ef-w-0.124617.3.v2
DELETED

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:5c069e81c193f5ba7a9c8cff114c5522e13d8efd16e2e8c055c880bf5010f334
-size 736165
flax_model.msgpack
DELETED

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:12aea1d6f15b37764f5615dcb6d6bc6cc56e7d74cd3ce88cdd0469817b5a9c29
-size 891625348
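Both deleted binaries (the TensorBoard event file and the Flax weights) are stored via Git LFS, so the diff shows only the three-line pointer: spec version, the object's SHA-256, and its size in bytes. A minimal sketch of deriving such a pointer from a local file (hypothetical helper, not part of this repo):

import hashlib
import os

def lfs_pointer(path):
    # Hash the file in 1 MiB chunks to avoid loading it all into memory.
    sha = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            sha.update(chunk)
    # A Git LFS pointer file is exactly these three lines.
    return (
        "version https://git-lfs.github.com/spec/v1\n"
        f"oid sha256:{sha.hexdigest()}\n"
        f"size {os.path.getsize(path)}\n"
    )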
src/{preparaing_recipe_nlg_dataset.py → create_dataset.py}
RENAMED

@@ -114,6 +114,7 @@ def main():
 
         return {
             "inputs": ner,
+            # "targets": f"title: {title} <section> ingredients: {ingredients} <section> directions: {steps}"
             "targets": f"title: {title} <section> ingredients: {ingredients} <section> directions: {steps}"
         }
 
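The rename leaves the mapping logic intact; the only textual change is a commented-out duplicate of the "targets" line. For context, a record-mapping function of this shape is typically applied with datasets.map; the column names below are illustrative, not confirmed against the script's actual CSV schema:

from datasets import load_dataset

def to_seq2seq(example):
    # Hypothetical column names; the real script builds `ner`, `title`,
    # `ingredients`, and `steps` from the RecipeNLG data.
    return {
        "inputs": example["ner"],
        "targets": f"title: {example['title']} <section> "
                   f"ingredients: {example['ingredients']} <section> "
                   f"directions: {example['steps']}",
    }

recipes = load_dataset("csv", data_files={"train": "train.csv"})
recipes = recipes.map(to_seq2seq)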
src/run.sh
CHANGED

@@ -5,6 +5,7 @@ export LANG=C.UTF-8
 
 export OUTPUT_DIR=/home/m3hrdadfi/code/t5-recipe-generation
 export MODEL_NAME_OR_PATH=t5-base
+# export MODEL_NAME_OR_PATH=flax-community/t5-recipe-generation
 export NUM_BEAMS=3
 
 export TRAIN_FILE=/home/m3hrdadfi/code/data/train.csv
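The commented-out export keeps a one-line toggle between starting from the base model (t5-base) and resuming from the already-pushed checkpoint (flax-community/t5-recipe-generation). A minimal sketch of how the variable would be consumed on the Python side, assuming the transformers Flax API (the actual training script reads it through its own argument parsing):

import os

from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration

# Resolves to "t5-base" or "flax-community/t5-recipe-generation",
# depending on which export line in run.sh is active.
model_name = os.environ.get("MODEL_NAME_OR_PATH", "t5-base")

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = FlaxT5ForConditionalGeneration.from_pretrained(model_name)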
src/run_recipe_nlg_flax.py
CHANGED

@@ -21,6 +21,7 @@ Fine-tuning the library models for recipe-generation.
 import logging
 import os
 import random
+import re
 import sys
 import time
 from dataclasses import dataclass, field
@@ -375,7 +376,7 @@ def main():
         data_files["test"] = data_args.test_file
         extension = data_args.test_file.split(".")[-1]
 
-
+    logger.info(data_files)
     dataset = load_dataset(
         extension,
         data_files=data_files,
@@ -551,10 +552,30 @@ def main():
     bleu = load_metric("sacrebleu")
     wer = load_metric("wer")
 
+    def skip_special_tokens_text(text):
+        new_text = []
+        for word in text.split():
+            word = word.strip()
+            if word:
+                if word not in special_tokens:
+                    new_text.append(word)
+
+        return " ".join(new_text)
+
+    def skip_special_tokens_texts(texts):
+        if isinstance(texts, list):
+            new_texts = [skip_special_tokens_text(text) for text in texts]
+        elif isinstance(texts, str):
+            new_texts = skip_special_tokens_text(texts)
+        else:
+            new_texts = []
+
+        return new_texts
+
     def postprocess_text(preds, labels):
-        preds = [pred.strip() for pred in preds]
-        labels_bleu = [[label.strip()] for label in labels]
-        labels_wer = [label.strip() for label in labels]
+        preds = [skip_special_tokens_texts(pred.strip()) for pred in preds]
+        labels_bleu = [[skip_special_tokens_texts(label.strip())] for label in labels]
+        labels_wer = [skip_special_tokens_texts(label.strip()) for label in labels]
 
         return preds, [labels_bleu, labels_wer]
 
@@ -846,11 +867,6 @@ def main():
                 push_to_hub=training_args.push_to_hub,
                 commit_message=f"Saving weights and logs of step {cur_step}",
             )
-            tokenizer.save_pretrained(
-                training_args.output_dir,
-                push_to_hub=training_args.push_to_hub,
-                commit_message=f"Saving tokenizer step {cur_step}",
-            )
 
 
 if __name__ == "__main__":
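The substantive fix in this file is stripping special tokens from predictions and references before scoring, since sacrebleu and wer would otherwise count markers such as <section> as ordinary words; the removed per-step tokenizer.save_pretrained call also drops a redundant push, likely because the tokenizer does not change between checkpoints. A self-contained sketch of the added post-processing; the special_tokens set here is a stand-in, as the script presumably collects the real one from the tokenizer:

# Stand-in for the `special_tokens` collection referenced by the new helpers.
special_tokens = {"<section>", "<pad>", "</s>"}

def skip_special_tokens_text(text):
    # Drop whitespace-delimited words that are special tokens.
    return " ".join(w for w in text.split() if w and w not in special_tokens)

print(skip_special_tokens_text("title: pasta <section> ingredients: flour eggs </s>"))
# -> title: pasta ingredients: flour eggs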