ydshieh committed
Commit · 93a6e2b
1 Parent(s): 99c1534
fix
Browse files
- run_image_captioning_flax.py +4 -14
run_image_captioning_flax.py CHANGED

@@ -22,8 +22,8 @@ import logging
 import os
 import sys
 import time
-from dataclasses import dataclass, field
-import
+from dataclasses import asdict, dataclass, field
+from enum import Enum
 from functools import partial
 from pathlib import Path
 from typing import Callable, Optional
@@ -61,9 +61,6 @@ from transformers.file_utils import get_full_repo_name, is_offline_mode
 
 logger = logging.getLogger(__name__)
 
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-
 try:
     nltk.data.find("tokenizers/punkt")
 except (LookupError, OSError):
@@ -115,6 +112,7 @@ class TrainingArguments:
     per_device_eval_batch_size: int = field(
         default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
     )
+    block_size: int = field(default=None, metadata={"help": "???"})
     learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
     weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
     adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
@@ -159,14 +157,6 @@ class TrainingArguments:
         return d
 
 
-@dataclass
-class CustomTrainingArguments(TrainingArguments):
-
-    do_predict_during_training: bool = field(default=None, metadata={"help": "???"})
-    do_predict_after_evaluation: bool = field(default=None, metadata={"help": "???"})
-    block_size: int = field(default=None, metadata={"help": "???"})
-
-
 @dataclass
 class ModelArguments:
     """
@@ -417,7 +407,7 @@ def main():
     # or by passing the --help flag to this script.
     # We now keep distinct sets of args, for a cleaner separation of concerns.
 
-    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, CustomTrainingArguments))
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
         # let's parse it to get our arguments.
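Net effect of the commit: the CustomTrainingArguments subclass is removed, block_size moves onto the script's own TrainingArguments dataclass, the two do_predict_* flags disappear, the TOKENIZERS_PARALLELISM override is dropped, and the asdict/Enum imports are restored. A minimal sketch of the dataclass-field pattern the diff relies on, assuming only the public HfArgumentParser API from transformers; the class below is a stub with just the fields visible in this hunk, not the script's full TrainingArguments:

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class TrainingArguments:
    # Each dataclass field becomes a CLI flag (--per_device_eval_batch_size,
    # --block_size) and the "help" metadata becomes its --help text.
    per_device_eval_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
    )
    # The diff types this as `int` with a None default; Optional[int] is the
    # equivalent spelling that parses cleanly.
    block_size: Optional[int] = field(default=None, metadata={"help": "???"})


parser = HfArgumentParser((TrainingArguments,))
# e.g. `python script.py --block_size 128` sets training_args.block_size to 128
(training_args,) = parser.parse_args_into_dataclasses()

With the subclass gone, do_predict_during_training and do_predict_after_evaluation no longer exist as flags; only block_size survives the move.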
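For the final hunk, the context comments point at the argument-dispatch idiom shared by the transformers example scripts. A hedged sketch of how the patched parser line is typically consumed; the stub dataclasses below stand in for the script's real ModelArguments and DataTrainingArguments, whose fields are not shown in this diff:

import os
import sys
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default=None)  # stub


@dataclass
class DataTrainingArguments:
    dataset_name: Optional[str] = field(default=None)  # stub


@dataclass
class TrainingArguments:
    output_dir: Optional[str] = field(default=None)  # stub


parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
    # A single .json argument is treated as a config file that supplies
    # every field of all three dataclasses at once.
    model_args, data_args, training_args = parser.parse_json_file(
        json_file=os.path.abspath(sys.argv[1])
    )
else:
    # Otherwise each field is read from the command line as a --flag.
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

Invoked as `python run_image_captioning_flax.py args.json`, every argument then comes from args.json; any other invocation falls back to ordinary CLI flags.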
|