ydshieh committed
Commit 93a6e2b · 1 Parent(s): 99c1534
Files changed (1)
  1. run_image_captioning_flax.py +4 -14
run_image_captioning_flax.py CHANGED
@@ -22,8 +22,8 @@ import logging
 import os
 import sys
 import time
-from dataclasses import dataclass, field
-import datetime
+from dataclasses import asdict, dataclass, field
+from enum import Enum
 from functools import partial
 from pathlib import Path
 from typing import Callable, Optional
@@ -61,9 +61,6 @@ from transformers.file_utils import get_full_repo_name, is_offline_mode
 
 logger = logging.getLogger(__name__)
 
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-
 try:
     nltk.data.find("tokenizers/punkt")
 except (LookupError, OSError):
@@ -115,6 +112,7 @@ class TrainingArguments:
     per_device_eval_batch_size: int = field(
         default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
     )
+    block_size: int = field(default=None, metadata={"help": "???"})
     learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
     weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
     adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
@@ -159,14 +157,6 @@
         return d
 
 
-@dataclass
-class CustomTrainingArguments(TrainingArguments):
-
-    do_predict_during_training: bool = field(default=None, metadata={"help": "???"})
-    do_predict_after_evaluation: bool = field(default=None, metadata={"help": "???"})
-    block_size: int = field(default=None, metadata={"help": "???"})
-
-
 @dataclass
 class ModelArguments:
     """
@@ -417,7 +407,7 @@ def main():
     # or by passing the --help flag to this script.
     # We now keep distinct sets of args, for a cleaner separation of concerns.
 
-    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, CustomTrainingArguments))
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
         # let's parse it to get our arguments.
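Note on the import change: the commit swaps `import datetime` for `asdict` and `Enum`, which fits the `TrainingArguments.to_dict` helper whose trailing `return d` is visible in the fourth hunk. A minimal sketch of that serialization pattern; the method body below is an assumption, since the diff shows only the final `return d`:

from dataclasses import asdict, dataclass, field
from enum import Enum

@dataclass
class TrainingArguments:
    per_device_eval_batch_size: int = field(default=8)
    block_size: int = field(default=None, metadata={"help": "???"})

    def to_dict(self):
        # Serialize the dataclass to a plain dict, unwrapping Enum members
        # into their raw values so the result is JSON-friendly.
        # (Assumed body; only `return d` appears in the diff.)
        d = asdict(self)
        for k, v in d.items():
            if isinstance(v, Enum):
                d[k] = v.value
        return d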
 
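With `CustomTrainingArguments` removed, the parser consumes the script's plain `TrainingArguments` dataclass directly. A usage sketch of the parsing branch visible in the last hunk, assuming `ModelArguments` and `DataTrainingArguments` from the same script are in scope:

import sys
from transformers import HfArgumentParser

parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
    # A single .json argument: load all three argument groups from the file.
    model_args, data_args, training_args = parser.parse_json_file(json_file=sys.argv[1])
else:
    # Otherwise parse regular command-line flags into the dataclasses.
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()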