|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""A few things commonly used across A LOT of config files."""
|
|
|
|
import string
|
|
|
|
import ml_collections as mlc
|
|
|
|
|
|
def input_for_quicktest(config_input, quicktest):
  """Shrinks input-pipeline settings in place for a quick test run.

  When `quicktest` is falsy, `config_input` is left untouched. Otherwise the
  batch size becomes 8, the shuffle buffer becomes 10 and `cache_raw` is
  switched off.

  Args:
    config_input: the (mutable) input config section to adjust.
    quicktest: whether a quick test run is requested.
  """
  if not quicktest:
    return
  # Applied in this fixed order: batch size, shuffle buffer, raw caching.
  quick_overrides = (
      ('batch_size', 8),
      ('shuffle_buffer_size', 10),
      ('cache_raw', False),
  )
  for field, value in quick_overrides:
    setattr(config_input, field, value)
|
|
|
|
|
|
def parse_arg(arg, lazy=False, **spec):
  """Turns the single-string `get_config` argument into a ConfigDict.

  Example use in the config file:

    import big_vision.configs.common as bvcc

    def get_config(arg):
      arg = bvcc.parse_arg(arg,
          res=(224, int),
          runlocal=False,
          schedule='short',
      )

      # ...

      config.shuffle_buffer = 250_000 if not arg.runlocal else 50

  Ways that values can be passed when launching:

    --config amazing.py:runlocal,schedule=long,res=128
    --config amazing.py:res=128
    --config amazing.py:runlocal  # A boolean needs no value for "true".
    --config amazing.py:runlocal=False  # Explicit false boolean.
    --config amazing.py:128  # The first spec entry may be passed unnamed alone.

  Boolean options use strict conversion: 'True'/'true' give True, and
  'False'/'false' or the empty string give False.

  Args:
    arg: the string argument passed to `get_config`. None or '' means that
      every option keeps its default.
    lazy: if True, options that are not listed in `spec` are accepted as well.
      Their values are converted with `autotype`, which tries bool, then int,
      then float, and falls back to the raw string. If False, such options
      raise an error.
    **spec: option names mapped to their defaults. A `(default, converter)`
      pair (e.g. `(224, int)`) names the string converter explicitly.
      Otherwise the converter is derived from the type of the default value.

  Returns:
    A ConfigDict (type_safe=False) with one entry per spec option, in spec
    order, followed by any lazily parsed extra options.

  Raises:
    ValueError: if `lazy` is False and `arg` contains options not in `spec`.
  """
  arg = arg or ''
  spec = {name: get_type_with_default(value) for name, value in spec.items()}
  result = mlc.ConfigDict(type_safe=False)

  # A lone token without ',' or '=' is shorthand. It is a bare boolean flag
  # when it names a spec option (or when there is no spec at all). Otherwise
  # it is the value of the first spec option.
  is_single_token = arg and ',' not in arg and '=' not in arg
  if is_single_token:
    if arg in spec or not spec:
      arg = f'{arg}=True'
    else:
      first_name = next(iter(spec))
      arg = f'{first_name}={arg}'

  # Split 'k1=v1,k2,...' into {key: raw string}. A key without '=' maps to
  # 'True', empty tokens are skipped, and a repeated key keeps its last value.
  raw_kv = {}
  for token in arg.split(','):
    if not token:
      continue
    key, sep, value = token.partition('=')
    raw_kv[key] = value if sep else 'True'

  # Fill every declared option, consuming it from raw_kv as we go.
  for name, (default, convert) in spec.items():
    raw = raw_kv.pop(name, None)
    result[name] = default if raw is None else convert(raw)

  # Whatever is left was not declared in spec.
  if raw_kv and not lazy:
    raise ValueError(f'Unhandled config args remain: {raw_kv}')
  for name, raw in raw_kv.items():
    result[name] = autotype(raw)

  return result
|
|
|
|
|
|
def get_type_with_default(v):
  """Returns `(default, string_to_type)` for one `parse_arg` spec entry.

  Three forms of spec value are handled:

  * A bool default gets a strict converter. It accepts only 'true', 'false'
    and '' (case-insensitive); 'false' and '' map to False.
  * A 2-element tuple/list `(default, converter)` is returned as-is. The
    converter must be callable: a type such as `int`, or any function that
    takes the raw string.
  * Anything else uses `type(v)` as the converter.

  Args:
    v: the spec value given to `parse_arg` for one option.

  Returns:
    A tuple `(default_value, converter)`.

  Raises:
    ValueError: if `v` is a tuple/list that is not a `(default, converter)`
      pair. The returned bool converter also raises ValueError for strings
      other than 'true'/'false'/''.
  """
  # bool('False') is True, so booleans need their own converter.
  if isinstance(v, bool):
    def strict_bool(x):
      lowered = x.lower()
      # Raise instead of assert so the check also runs under `python -O`.
      if lowered not in {'true', 'false', ''}:
        raise ValueError(
            f'Invalid boolean config value {x!r}; expected true/false.')
      return lowered == 'true'
    return (v, strict_bool)

  # An explicit (default, converter) pair. Tuple/list *defaults* are rejected,
  # because ',' is the argument delimiter.
  if isinstance(v, (tuple, list)):
    if not (len(v) == 2 and callable(v[1])):
      raise ValueError(
          'List or tuple types are currently not supported because we use `,` as'
          ' dumb delimiter. Contributions (probably using ast) welcome. You can'
          ' unblock by using a string with eval(s.replace(";", ",")) or similar')
    return (v[0], v[1])

  # Otherwise, derive the converter from the default value's type.
  return (v, type(v))
|
|
|
|
|
|
def autotype(x):
  """Best-effort conversion of a string to bool, int or float.

  'true'/'false' (any case) become bools. Otherwise int is tried first, then
  float. A string that fits none of these is returned unchanged.
  """
  assert isinstance(x, str)
  lowered = x.lower()
  if lowered in ('true', 'false'):
    return lowered == 'true'
  # Try the narrower numeric type first so '3' stays an int.
  for convert in (int, float):
    try:
      return convert(x)
    except ValueError:
      pass
  return x
|
|
|
|
|
|
def pack_arg(**kw):
  """Serializes keyword args into a 'k1=v1,k2=v2' string for `parse_arg()`.

  Each value is rendered with f-string formatting. Because ',' separates
  entries, no rendered value may contain a comma.
  """
  # Check every value before building anything.
  for value in kw.values():
    rendered = f'{value}'
    assert ',' not in rendered, f"Can't use `,` in config_arg value: {value}"
  return ','.join(f'{key}={value}' for key, value in kw.items())
|
|
|
|
|
|
def arg(**kw):
  """Bundles keyword args together with their packed `config_arg` string.

  Use like `add(**bvcc.arg(res=256, foo=bar), lr=0.1)` to pass config_arg.

  Returns:
    A dict whose first entry, 'config_arg', holds `pack_arg(**kw)`, followed
    by every entry of `kw` itself.
  """
  bundled = {'config_arg': pack_arg(**kw)}
  bundled.update(kw)
  return bundled
|
|
|
|
|
|
def _get_field_ref(config_dict, field_name):
  """Resolves a dotted `field_name` and returns `get_ref` of its last part.

  Every component except the last is followed via attribute access. The last
  component is passed to `get_ref` on the object reached that way.
  """
  *parents, leaf = field_name.split('.')
  node = config_dict
  for attr in parents:
    node = getattr(node, attr)
  return node.get_ref(leaf)
|
|
|
|
|
|
def format_str(format_string, config):
  """Format string with reference fields from config.

  This makes it easy to build preprocess strings that contain references to
  fields that are edited after. E.g.:

  ```
  config = mlc.ConfigDict()
  config.res = (256, 256)
  config.pp = bvcc.format_str('resize({res})', config)
  ...
  # if config.res is modified (e.g. via sweeps) it will propagate to pp field:
  config.res = (512, 512)
  assert config.pp == 'resize((512, 512))'
  ```

  Each placeholder must name a config field, optionally dotted (e.g. `{res}`
  or `{model.width}`). Format specs (`{res:>3}`), conversions (`{res!r}`) and
  empty placeholders (`{}`) are rejected. Literal braces are written `{{` and
  `}}`.

  Args:
    format_string: string to format with references.
    config: ConfigDict to get references to format the string.

  Returns:
    A reference field which renders a string using references to config
    fields, or a plain `str` if `format_string` has no placeholders.

  Raises:
    ValueError: on a format spec, a conversion, or an empty placeholder.
  """
  output = ''
  for literal_text, field_name, format_spec, conversion in (
      string.Formatter().parse(format_string)):
    # Only plain field references can be forwarded through `to_str()`.
    if format_spec or conversion:
      raise ValueError(
          'format_str does not support format specs or conversions: '
          f'{format_string!r}')
    output += literal_text
    if field_name is None:  # Literal text only, no placeholder in this part.
      continue
    # `{}` yields field_name == ''. Previously it was silently dropped.
    if not field_name:
      raise ValueError(
          'format_str requires named placeholders, got an empty `{}` in: '
          f'{format_string!r}')
    # Adding a reference to the str turns `output` into a reference field.
    output += _get_field_ref(config, field_name).to_str()
  return output
|
|
|