File size: 6,612 Bytes
fa1a600 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A few things commonly used across A LOT of config files."""
import string
import ml_collections as mlc
def input_for_quicktest(config_input, quicktest):
    """Shrinks the input-pipeline settings when running a quick test.

    When `quicktest` is falsy, `config_input` is left untouched. Otherwise
    `batch_size`, `shuffle_buffer_size` and `cache_raw` are overwritten in
    place with small values.

    Args:
      config_input: the (mutable) input section of a config.
      quicktest: whether the quick-test overrides should be applied.
    """
    if not quicktest:
        return
    config_input.batch_size = 8
    config_input.shuffle_buffer_size = 10
    config_input.cache_raw = False
def parse_arg(arg, lazy=False, **spec):
    """Parses the single-string argument of a ConfigDict's `get_config`.

    Example use in a config file:

      import big_vision.configs.common as bvcc

      def get_config(arg):
        arg = bvcc.parse_arg(arg,
                             res=(224, int),
                             runlocal=False,
                             schedule='short',
                             )
        # ...
        config.shuffle_buffer = 250_000 if not arg.runlocal else 50

    Ways that values can be passed when launching:

      --config amazing.py:runlocal,schedule=long,res=128
      --config amazing.py:res=128
      --config amazing.py:runlocal  # A boolean needs no value for "true".
      --config amazing.py:runlocal=False  # Explicit false boolean.
      --config amazing.py:128  # The first spec entry may be passed unnamed alone.

    Bool options use strict conversion: 'true'/'false' (any letter case) map
    to True/False, '' maps to False, and other strings are not accepted.

    Args:
      arg: the string argument passed to `get_config`; None is treated like ''.
      lazy: if True, keys that are not in `spec` are accepted and converted
        with `autotype` (bool, else int, else float, else str). If False, such
        keys cause a ValueError.
      **spec: the names and default values of the expected options. If a value
        is a `(default, type_fn)` tuple, `type_fn` is called on the raw string.
        Otherwise the conversion is derived from the default value's type.

    Returns:
      A ConfigDict (type_safe=False, used for dot-access) holding every spec
      entry's converted value (or its default when not given), plus the extra
      keys when `lazy` is True.

    Raises:
      ValueError: if keys not in `spec` remain and `lazy` is False.
    """
    arg = arg or ''  # Treat None like an empty argument string.
    spec = {name: get_type_with_default(value) for name, value in spec.items()}
    result = mlc.ConfigDict(type_safe=False)  # Used purely for dot-access.

    # A lone token with neither `,` nor `=` is shorthand.
    is_single_token = bool(arg) and ',' not in arg and '=' not in arg
    if is_single_token:
        if not spec or arg in spec:
            # e.g. `:runlocal` -> the named flag set to true.
            arg = f'{arg}=True'
        else:
            # e.g. `:128` -> value of the first spec entry (relies on dicts
            # preserving insertion order, guaranteed since Python 3.7).
            first_name = next(iter(spec))
            arg = f'{first_name}={arg}'

    # Split `k1=v1,k2,...` into raw strings; a bare key means 'True'.
    raw_kv = {}
    for token in arg.split(','):
        if not token:
            continue
        key, sep, value = token.partition('=')
        raw_kv[key] = value if sep else 'True'

    # Every spec entry gets either its converted value or its default.
    for name, (default, convert) in spec.items():
        if name in raw_kv:
            result[name] = convert(raw_kv.pop(name))
        else:
            result[name] = default

    if not raw_kv:
        return result
    if not lazy:
        raise ValueError(f'Unhandled config args remain: {raw_kv}')
    for key, value in raw_kv.items():
        result[key] = autotype(value)
    return result
def get_type_with_default(v):
    """Returns `(default, converter)` for one `parse_arg` spec entry.

    Args:
      v: a spec value. One of:
        * a bool default, which gets a strict string-to-bool converter;
        * a `(default, converter)` pair (tuple or list of length 2 whose second
          element is callable, e.g. `(224, int)`);
        * any other default value, whose type is used as the converter.

    Returns:
      A `(default, converter)` tuple, where `converter` maps the raw string
      from the command line to the option's value.

    Raises:
      ValueError: if `v` is a tuple/list that is not a `(default, callable)`
        pair. The returned bool converter raises ValueError for strings other
        than 'true', 'false' or '' (case-insensitive).
    """
    # Bools need a dedicated parser: `bool('False')` would be True.
    if isinstance(v, bool):
        def strict_bool(x):
            lowered = x.lower()
            # Raise rather than assert so bad input is still rejected under -O.
            if lowered not in {'true', 'false', ''}:
                raise ValueError(
                    f"Cannot parse {x!r} as a bool; expected 'true', 'false' "
                    "or ''.")
            return lowered == 'true'
        return (v, strict_bool)
    # An explicit (default, converter) pair. Any callable is accepted, matching
    # `parse_arg`'s documented "function called to convert the string".
    if isinstance(v, (tuple, list)):
        if len(v) != 2 or not callable(v[1]):
            raise ValueError(
                'List or tuple types are currently not supported because we use `,` as'
                ' dumb delimiter. Contributions (probably using ast) welcome. You can'
                ' unblock by using a string with eval(s.replace(";", ",")) or similar')
        return (v[0], v[1])
    # Otherwise, derive the converter from the default value's type.
    return (v, type(v))
def autotype(x):
    """Converts a string to the most specific of bool, int, float or str.

    'true'/'false' (any letter case) become bools. Otherwise the string is
    tried as an int, then as a float, and returned unchanged if neither parses.
    """
    assert isinstance(x, str)
    lowered = x.lower()
    if lowered in ('true', 'false'):
        return lowered == 'true'
    for convert in (int, float):
        try:
            return convert(x)
        except ValueError:
            pass
    return x
def pack_arg(**kw):
    """Packs keyword args into a `k1=v1,k2=v2` string for `parse_arg()`.

    Values are rendered with f-string formatting. Because `,` is the
    delimiter, a value whose text contains `,` is rejected.
    """
    pieces = []
    for key, value in kw.items():
        assert ',' not in f'{value}', f"Can't use `,` in config_arg value: {value}"
        pieces.append(f'{key}={value}')
    return ','.join(pieces)
def arg(**kw):
    """Use like `add(**bvcc.arg(res=256, foo=bar), lr=0.1)` to pass config_arg.

    Returns:
      A dict with `'config_arg'` set to `pack_arg(**kw)`, followed by the
      keyword args themselves (a keyword named `config_arg` overrides it).
    """
    packed = pack_arg(**kw)
    result = {'config_arg': packed}
    result.update(kw)
    return result
def _get_field_ref(config_dict, field_name):
path = field_name.split('.')
for field in path[:-1]:
config_dict = getattr(config_dict, field)
return config_dict.get_ref(path[-1])
def format_str(format_string, config):
    """Format string with reference fields from config.

    This makes it easy to build preprocess strings that contain references to
    fields that are edited afterwards. E.g.:
    ```
    config = mlc.ConfigDict()
    config.res = (256, 256)
    config.pp = bvcc.format_str('resize({res})', config)
    ...
    # if config.res is modified (e.g. via sweeps) it will propagate to pp field:
    config.res = (512, 512)
    assert config.pp == 'resize((512, 512))'
    ```

    Fields must be named. A dotted name such as `{model.res}` is resolved
    through nested attributes of `config`. Literal braces are written as `{{`
    and `}}`.

    Args:
      format_string: string to format with references.
      config: ConfigDict to get references to format the string.

    Returns:
      A reference field which renders a string using references to config
      fields. If `format_string` contains no fields, the plain (unescaped)
      string is returned instead.

    Raises:
      ValueError: if a field has no name (`{}`), or uses a format spec
        (`{res:>5}`) or a conversion (`{res!r}`). None of these are supported,
        and previously they were silently dropped.
    """
    output = ''
    parts = string.Formatter().parse(format_string)
    for literal_text, field_name, format_spec, conversion in parts:
        output += literal_text
        if field_name is None:
            continue  # Trailing literal text with no replacement field.
        if not field_name:
            raise ValueError(
                f'Positional fields are not supported in {format_string!r}; '
                'name the config field, e.g. {res}.')
        if format_spec or conversion:
            raise ValueError(
                f'Format specs and conversions are not supported in '
                f'{format_string!r} (field {field_name!r}).')
        # `+=` rather than ''.join: the `.to_str()` pieces are reference
        # objects (per the Returns contract above), not plain strings.
        output += _get_field_ref(config, field_name).to_str()
    return output
|