ydshieh committed on
Commit
0fc714d
·
1 Parent(s): 7245cb4

use custom TrainingArguments

Browse files
Files changed (1) hide show
  1. run_image_captioning_flax.py +70 -3
run_image_captioning_flax.py CHANGED
@@ -52,7 +52,7 @@ from transformers import (
52
  AutoFeatureExtractor,
53
  AutoTokenizer,
54
  HfArgumentParser,
55
- TrainingArguments,
56
  is_tensorboard_available,
57
  FlaxAutoModelForVision2Seq,
58
  )
@@ -92,6 +92,72 @@ def shift_tokens_right(input_ids: np.ndarray, pad_token_id: int, decoder_start_t
92
  return shifted_input_ids
93
 
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  @dataclass
96
  class CustomTrainingArguments(TrainingArguments):
97
 
@@ -1229,7 +1295,7 @@ def main():
1229
 
1230
  # ======================== Evaluating ==============================
1231
 
1232
- if training_args.do_eval and ((training_args.eval_steps is not None and cur_step % training_args.eval_steps) or cur_step % steps_per_epoch == 0):
1233
  run_eval_or_test(input_rng, eval_dataset, name="valid", is_inside_training=True)
1234
 
1235
  # ======================== Prediction loop ==============================
@@ -1238,8 +1304,9 @@ def main():
1238
  if training_args.do_predict and training_args.do_predict_during_training and training_args.do_predict_after_evaluation:
1239
  run_eval_or_test(input_rng, predict_dataset, name='test', is_inside_training=True)
1240
 
1241
- # ======================== Save ==============================
1242
 
 
1243
  save_results(epoch + 1, cur_step)
1244
 
1245
  # run prediction after each epoch (if not done during training)
 
52
  AutoFeatureExtractor,
53
  AutoTokenizer,
54
  HfArgumentParser,
55
+ # TrainingArguments,
56
  is_tensorboard_available,
57
  FlaxAutoModelForVision2Seq,
58
  )
 
92
  return shifted_input_ids
93
 
94
 
95
@dataclass
class TrainingArguments:
    """Training hyper-parameters for the Flax image-captioning script.

    A slimmed-down, script-local replacement for
    ``transformers.TrainingArguments``: it keeps only the fields this script
    actually reads, so ``HfArgumentParser`` builds exactly the CLI the script
    supports, without pulling in the PyTorch-oriented options of the original.
    """

    # Required; "~" is expanded in __post_init__.
    output_dir: str = field(
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
    )
    overwrite_output_dir: bool = field(
        default=False,
        metadata={
            "help": (
                "Overwrite the content of the output directory. "
                "Use this to continue training if output_dir points to a checkpoint directory."
            )
        },
    )
    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
    do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
    do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
    per_device_train_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
    )
    per_device_eval_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
    )
    learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
    weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
    adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
    adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
    adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
    label_smoothing_factor: float = field(
        default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
    )
    adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
    num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
    warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
    logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
    save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
    # NOTE(review): annotated `int` but defaults to None (None means "eval at
    # epoch boundaries only" in the training loop) — `Optional[int]` would be
    # the accurate annotation; kept as-is to avoid touching the import block.
    eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
    seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
    push_to_hub: bool = field(
        default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
    )
    # NOTE(review): `hub_model_id` / `hub_token` are annotated `str` but default
    # to None — `Optional[str]` would be more accurate.
    hub_model_id: str = field(
        default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
    )
    hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})

    def __post_init__(self):
        # Expand "~" so downstream path handling sees an absolute-ish path.
        if self.output_dir is not None:
            self.output_dir = os.path.expanduser(self.output_dir)

    def to_dict(self):
        """
        Serialize this instance, replacing `Enum` members by their values (for JSON
        serialization support). Values of fields whose names end with `_token` are
        replaced by a `<FIELD_NAME>` placeholder so secrets are not leaked when the
        arguments are logged or saved.
        """
        d = asdict(self)
        for k, v in d.items():
            if isinstance(v, Enum):
                d[k] = v.value
            if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
                d[k] = [x.value for x in v]
            if k.endswith("_token"):
                d[k] = f"<{k.upper()}>"
        return d
161
  @dataclass
162
  class CustomTrainingArguments(TrainingArguments):
163
 
 
1295
 
1296
  # ======================== Evaluating ==============================
1297
 
1298
+ if training_args.do_eval and ((training_args.eval_steps is not None and cur_step % training_args.eval_steps == 0) or cur_step % steps_per_epoch == 0):
1299
  run_eval_or_test(input_rng, eval_dataset, name="valid", is_inside_training=True)
1300
 
1301
  # ======================== Prediction loop ==============================
 
1304
  if training_args.do_predict and training_args.do_predict_during_training and training_args.do_predict_after_evaluation:
1305
  run_eval_or_test(input_rng, predict_dataset, name='test', is_inside_training=True)
1306
 
1307
+ # ======================== Save ==============================
1308
 
1309
+ if cur_step % training_args.save_steps == 0:
1310
  save_results(epoch + 1, cur_step)
1311
 
1312
  # run prediction after each epoch (if not done during training)