File size: 6,612 Bytes
fa1a600
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A few things commonly used across A LOT of config files."""

import string

import ml_collections as mlc


def input_for_quicktest(config_input, quicktest):
  """Shrinks the input settings in place for a quick test run.

  Args:
    config_input: object whose `batch_size`, `shuffle_buffer_size` and
      `cache_raw` attributes get overwritten; it is mutated in place.
    quicktest: if falsy, `config_input` is left untouched.
  """
  if not quicktest:
    return
  # Tiny values so that a quick test starts fast and uses little memory.
  quick_overrides = (
      ('batch_size', 8),
      ('shuffle_buffer_size', 10),
      ('cache_raw', False),
  )
  for field, value in quick_overrides:
    setattr(config_input, field, value)


def parse_arg(arg, lazy=False, **spec):
  """Turns the single string handed to `get_config` into typed options.

  Example use in the config file:

    import big_vision.configs.common as bvcc

    def get_config(arg):
      arg = bvcc.parse_arg(arg,
          res=(224, int),
          runlocal=False,
          schedule='short',
      )

      # ...

      config.shuffle_buffer = 250_000 if not arg.runlocal else 50

  Ways that values can be passed when launching:

    --config amazing.py:runlocal,schedule=long,res=128
    --config amazing.py:res=128
    --config amazing.py:runlocal  # A boolean needs no value for "true".
    --config amazing.py:runlocal=False  # Explicit false boolean.
    --config amazing.py:128  # The first spec entry may be passed unnamed alone.

  Booleans are converted strictly: 'True'/'true' give True, and
  'False'/'false'/'' give False.

  Args:
    arg: the string argument that's passed to get_config. `None` is treated
      like the empty string.
    lazy: if True, options that are not listed in `spec` are accepted too and
      their type is guessed by `autotype` (bool, int, float, else str).
    **spec: the names and default values of the expected options.
      If a value is a `(default, converter)` pair, the converter is called on
      the passed string. Otherwise the converter is derived from the type of
      the default value.

  Returns:
    ConfigDict object with extracted type-converted values.

  Raises:
    ValueError: if `lazy` is False and options outside of `spec` were passed.
  """
  arg = arg or ''  # `None` means "nothing was passed".

  # name -> (default, converter), in the order the caller listed them.
  converters = {}
  for option, default_or_pair in spec.items():
    converters[option] = get_type_with_default(default_or_pair)

  # type_safe=False: the ConfigDict is only used for convenient dot-access.
  parsed = mlc.ConfigDict(type_safe=False)

  # A lone token without `,` or `=` is a shorthand that gets expanded first.
  is_lone_token = bool(arg) and ',' not in arg and '=' not in arg
  if is_lone_token:
    if not converters or arg in converters:
      # (think :runlocal) The token names an option: treat it as a bool flag.
      arg = f'{arg}=True'
    else:
      # The token is the value of the first option in the spec. This relies
      # on dicts keeping insertion order (Py3.7+).
      first_option = next(iter(converters))
      arg = f'{first_option}={arg}'

  # Split the string into raw `name -> string value` pairs.
  pending = {}
  for token in arg.split(','):
    if not token:
      continue
    option, equals, raw_value = token.partition('=')
    # A name without `=` counts as a flag that was switched on.
    pending[option] = raw_value if equals else 'True'

  # Walk the spec: convert what was passed, fall back to defaults otherwise.
  for option, (default, type_fn) in converters.items():
    if option in pending:
      parsed[option] = type_fn(pending.pop(option))
    else:
      parsed[option] = default

  # Whatever is still pending was not declared in the spec.
  if pending and not lazy:
    raise ValueError(f'Unhandled config args remain: {pending}')
  for option, raw_value in pending.items():
    parsed[option] = autotype(raw_value)

  return parsed


def get_type_with_default(v):
  """Returns `(default, converter)` for a single `parse_arg` spec value.

  Args:
    v: the spec value, interpreted as follows.
      - A bool: it is the default, and the converter parses strings strictly,
        accepting only 'true', 'false' (in any letter case) and ''.
      - A 2-element tuple or list `(default, converter)`: used as is. The
        converter may be a type such as `int` or any other callable that
        takes the string value.
      - Anything else: it is the default and `type(v)` is the converter.

  Returns:
    A `(default, converter)` tuple, where `converter` maps the string passed
    on the command line to the final value.

  Raises:
    AssertionError: if `v` is a tuple or list that is not a valid
      `(default, converter)` pair. The returned bool converter raises
      AssertionError too when given a string it does not accept.
  """
  # For bool, do safe string conversion: `bool('False')` would be True.
  if isinstance(v, bool):
    def strict_bool(x):
      lowered = x.lower()
      # Raised explicitly (instead of `assert`) so that the check is not
      # stripped under `python -O`, which would silently turn typos into False.
      if lowered not in {'true', 'false', ''}:
        raise AssertionError(
            f'Cannot convert {x!r} to bool, expected true or false.')
      return lowered == 'true'
    return (v, strict_bool)
  # If already a (default, converter) pair, use that.
  if isinstance(v, (tuple, list)):
    # Any callable is a valid converter, not only types. A tuple-valued
    # default such as (224, 224) is still rejected since 224 is not callable.
    if not (len(v) == 2 and callable(v[1])):
      raise AssertionError(
          'List or tuple types are currently not supported because we use `,` as'
          ' dumb delimiter. Contributions (probably using ast) welcome. You can'
          ' unblock by using a string with eval(s.replace(";", ",")) or similar')
    return (v[0], v[1])
  # Otherwise, derive the converter from the type of the default value.
  return (v, type(v))


def autotype(x):
  """Converts a string to the most specific of bool, int, float or str.

  Args:
    x: the string to convert.

  Returns:
    A bool for 'true'/'false' (in any letter case), else an int if `int(x)`
    succeeds, else a float if `float(x)` succeeds, else `x` unchanged.
  """
  assert isinstance(x, str)
  lowered = x.lower()
  if lowered == 'true':
    return True
  if lowered == 'false':
    return False
  # Try the narrower numeric type first, so that '3' stays an int.
  for convert in (int, float):
    try:
      return convert(x)
    except ValueError:
      pass
  return x  # Not a bool and not a number: keep it as str.


def pack_arg(**kw):
  """Joins keyword args into a `k=v,k=v` string that `parse_arg()` can read.

  Args:
    **kw: the options to pack. The string form of a value must not contain
      `,` since that is the delimiter between options.

  Returns:
    The options rendered as `name=value` and joined by `,`, in the order they
    were passed. An empty string if no options were passed.
  """
  # Validate everything up-front, before building any output.
  for value in kw.values():
    assert ',' not in f'{value}', f"Can't use `,` in config_arg value: {value}"
  pairs = []
  for name, value in kw.items():
    pairs.append(f'{name}={value}')
  return ','.join(pairs)


def arg(**kw):
  """Returns `kw` plus a `config_arg` entry holding `kw` packed as a string.

  Use like `add(**bvcc.arg(res=256, foo=bar), lr=0.1)` to pass config_arg.

  Args:
    **kw: the options; they are packed by `pack_arg` and also passed through.

  Returns:
    A dict whose first key is `config_arg`, followed by the entries of `kw`.
  """
  combined = {'config_arg': pack_arg(**kw)}
  # Applied last, so an explicit `config_arg` keyword wins over the packed one.
  combined.update(kw)
  return combined


def _get_field_ref(config_dict, field_name):
  """Returns `get_ref` of the field at the dotted path `field_name`.

  Args:
    config_dict: the object to start from. Every path component except the
      last is looked up with `getattr`; the object reached that way must
      provide `get_ref`.
    field_name: a field name, optionally nested with dots such as `a.b.c`.

  Returns:
    Whatever `get_ref(<last path component>)` returns on the innermost object.
  """
  *parents, leaf = field_name.split('.')
  node = config_dict
  for attribute in parents:
    node = getattr(node, attribute)
  return node.get_ref(leaf)


def format_str(format_string, config):
  """Formats a string using references to fields of `config`.

  This makes it easy to build preprocess strings that contain references to
  fields that are edited afterwards. E.g.:

  ```
  config = mlc.ConfigDict()
  config.res = (256, 256)
  config.pp = bvcc.format_str('resize({res})', config)
  ...
  # if config.res is modified (e.g. via sweeps) it will propagate to pp field:
  config.res = (512, 512)
  assert config.pp == 'resize((512, 512))'
  ```

  Args:
    format_string: string to format with references. Only plain `{name}`
      fields are supported (dotted names are allowed); format specs and
      conversions such as `{x:.2f}` or `{x!r}` are rejected by an assertion.
    config: ConfigDict to get references to format the string.

  Returns:
    A reference field which renders a string using references to config
    fields. If `format_string` contains no fields, a plain str is returned.
  """
  rendered = ''
  for text, name, spec, conv in string.Formatter().parse(format_string):
    assert not (spec or conv)
    rendered += text
    if not name:
      continue
    # Appending the reference (rather than its current value) is what keeps
    # the result in sync with later edits of the config field.
    rendered += _get_field_ref(config, name).to_str()
  return rendered