ydshieh
committed on
Commit
·
0fc714d
1
Parent(s):
7245cb4
use custom TrainingArguments
Browse files- run_image_captioning_flax.py +70 -3
run_image_captioning_flax.py
CHANGED
@@ -52,7 +52,7 @@ from transformers import (
|
|
52 |
AutoFeatureExtractor,
|
53 |
AutoTokenizer,
|
54 |
HfArgumentParser,
|
55 |
-
TrainingArguments,
|
56 |
is_tensorboard_available,
|
57 |
FlaxAutoModelForVision2Seq,
|
58 |
)
|
@@ -92,6 +92,72 @@ def shift_tokens_right(input_ids: np.ndarray, pad_token_id: int, decoder_start_t
|
|
92 |
return shifted_input_ids
|
93 |
|
94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
@dataclass
|
96 |
class CustomTrainingArguments(TrainingArguments):
|
97 |
|
@@ -1229,7 +1295,7 @@ def main():
|
|
1229 |
|
1230 |
# ======================== Evaluating ==============================
|
1231 |
|
1232 |
-
if training_args.do_eval and ((training_args.eval_steps is not None and cur_step % training_args.eval_steps) or cur_step % steps_per_epoch == 0):
|
1233 |
run_eval_or_test(input_rng, eval_dataset, name="valid", is_inside_training=True)
|
1234 |
|
1235 |
# ======================== Prediction loop ==============================
|
@@ -1238,8 +1304,9 @@ def main():
|
|
1238 |
if training_args.do_predict and training_args.do_predict_during_training and training_args.do_predict_after_evaluation:
|
1239 |
run_eval_or_test(input_rng, predict_dataset, name='test', is_inside_training=True)
|
1240 |
|
1241 |
-
|
1242 |
|
|
|
1243 |
save_results(epoch + 1, cur_step)
|
1244 |
|
1245 |
# run prediction after each epoch (if not done during training)
|
|
|
52 |
AutoFeatureExtractor,
|
53 |
AutoTokenizer,
|
54 |
HfArgumentParser,
|
55 |
+
# TrainingArguments,
|
56 |
is_tensorboard_available,
|
57 |
FlaxAutoModelForVision2Seq,
|
58 |
)
|
|
|
92 |
return shifted_input_ids
|
93 |
|
94 |
|
95 |
+
@dataclass
|
96 |
+
class TrainingArguments:
|
97 |
+
output_dir: str = field(
|
98 |
+
metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
|
99 |
+
)
|
100 |
+
overwrite_output_dir: bool = field(
|
101 |
+
default=False,
|
102 |
+
metadata={
|
103 |
+
"help": (
|
104 |
+
"Overwrite the content of the output directory. "
|
105 |
+
"Use this to continue training if output_dir points to a checkpoint directory."
|
106 |
+
)
|
107 |
+
},
|
108 |
+
)
|
109 |
+
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
|
110 |
+
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
|
111 |
+
do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
|
112 |
+
per_device_train_batch_size: int = field(
|
113 |
+
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
|
114 |
+
)
|
115 |
+
per_device_eval_batch_size: int = field(
|
116 |
+
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
|
117 |
+
)
|
118 |
+
learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
|
119 |
+
weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
|
120 |
+
adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
|
121 |
+
adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
|
122 |
+
adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
|
123 |
+
label_smoothing_factor: float = field(
|
124 |
+
default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
|
125 |
+
)
|
126 |
+
adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
|
127 |
+
num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
|
128 |
+
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
|
129 |
+
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
|
130 |
+
save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
|
131 |
+
eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
|
132 |
+
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
|
133 |
+
push_to_hub: bool = field(
|
134 |
+
default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
|
135 |
+
)
|
136 |
+
hub_model_id: str = field(
|
137 |
+
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
|
138 |
+
)
|
139 |
+
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
|
140 |
+
|
141 |
+
def __post_init__(self):
|
142 |
+
if self.output_dir is not None:
|
143 |
+
self.output_dir = os.path.expanduser(self.output_dir)
|
144 |
+
|
145 |
+
def to_dict(self):
|
146 |
+
"""
|
147 |
+
Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
|
148 |
+
the token values by removing their value.
|
149 |
+
"""
|
150 |
+
d = asdict(self)
|
151 |
+
for k, v in d.items():
|
152 |
+
if isinstance(v, Enum):
|
153 |
+
d[k] = v.value
|
154 |
+
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
|
155 |
+
d[k] = [x.value for x in v]
|
156 |
+
if k.endswith("_token"):
|
157 |
+
d[k] = f"<{k.upper()}>"
|
158 |
+
return d
|
159 |
+
|
160 |
+
|
161 |
@dataclass
|
162 |
class CustomTrainingArguments(TrainingArguments):
|
163 |
|
|
|
1295 |
|
1296 |
# ======================== Evaluating ==============================
|
1297 |
|
1298 |
+
if training_args.do_eval and ((training_args.eval_steps is not None and cur_step % training_args.eval_steps == 0) or cur_step % steps_per_epoch == 0):
|
1299 |
run_eval_or_test(input_rng, eval_dataset, name="valid", is_inside_training=True)
|
1300 |
|
1301 |
# ======================== Prediction loop ==============================
|
|
|
1304 |
if training_args.do_predict and training_args.do_predict_during_training and training_args.do_predict_after_evaluation:
|
1305 |
run_eval_or_test(input_rng, predict_dataset, name='test', is_inside_training=True)
|
1306 |
|
1307 |
+
# ======================== Save ==============================
|
1308 |
|
1309 |
+
if cur_step % training_args.save_steps == 0:
|
1310 |
save_results(epoch + 1, cur_step)
|
1311 |
|
1312 |
# run prediction after each epoch (if not done during training)
|