ndkhanh95's picture
Upload 304 files
fa1a600 verified
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A few things commonly used across A LOT of config files."""
import string
import ml_collections as mlc
def input_for_quicktest(config_input, quicktest):
if quicktest:
config_input.batch_size = 8
config_input.shuffle_buffer_size = 10
config_input.cache_raw = False
def parse_arg(arg, lazy=False, **spec):
"""Makes ConfigDict's get_config single-string argument more usable.
Example use in the config file:
import big_vision.configs.common as bvcc
def get_config(arg):
arg = bvcc.parse_arg(arg,
res=(224, int),
runlocal=False,
schedule='short',
)
# ...
config.shuffle_buffer = 250_000 if not arg.runlocal else 50
Ways that values can be passed when launching:
--config amazing.py:runlocal,schedule=long,res=128
--config amazing.py:res=128
--config amazing.py:runlocal # A boolean needs no value for "true".
--config amazing.py:runlocal=False # Explicit false boolean.
--config amazing.py:128 # The first spec entry may be passed unnamed alone.
Uses strict bool conversion (converting 'True', 'true' to True, and 'False',
'false', '' to False).
Args:
arg: the string argument that's passed to get_config.
lazy: allow lazy parsing of arguments, which are not in spec. For these,
the type is auto-extracted in dependence of most complex possible type.
**spec: the name and default values of the expected options.
If the value is a tuple, the value's first element is the default value,
and the second element is a function called to convert the string.
Otherwise the type is automatically extracted from the default value.
Returns:
ConfigDict object with extracted type-converted values.
"""
# Normalize arg and spec layout.
arg = arg or '' # Normalize None to empty string
spec = {k: get_type_with_default(v) for k, v in spec.items()}
result = mlc.ConfigDict(type_safe=False) # For convenient dot-access only.
# Expand convenience-cases for a single parameter without = sign.
if arg and ',' not in arg and '=' not in arg:
# (think :runlocal) If it's the name of sth in the spec (or there is no
# spec), it's that in bool.
if arg in spec or not spec:
arg = f'{arg}=True'
# Otherwise, it is the value for the first entry in the spec.
else:
arg = f'{list(spec.keys())[0]}={arg}'
# Yes, we rely on Py3.7 insertion order!
# Now, expand the `arg` string into a dict of keys and values:
raw_kv = {raw_arg.split('=')[0]:
raw_arg.split('=', 1)[-1] if '=' in raw_arg else 'True'
for raw_arg in arg.split(',') if raw_arg}
# And go through the spec, using provided or default value for each:
for name, (default, type_fn) in spec.items():
val = raw_kv.pop(name, None)
result[name] = type_fn(val) if val is not None else default
if raw_kv:
if lazy: # Process args which are not in spec.
for k, v in raw_kv.items():
result[k] = autotype(v)
else:
raise ValueError(f'Unhandled config args remain: {raw_kv}')
return result
def get_type_with_default(v):
"""Returns (v, string_to_v_type) with lenient bool parsing."""
# For bool, do safe string conversion.
if isinstance(v, bool):
def strict_bool(x):
assert x.lower() in {'true', 'false', ''}
return x.lower() == 'true'
return (v, strict_bool)
# If already a (default, type) tuple, use that.
if isinstance(v, (tuple, list)):
assert len(v) == 2 and isinstance(v[1], type), (
'List or tuple types are currently not supported because we use `,` as'
' dumb delimiter. Contributions (probably using ast) welcome. You can'
' unblock by using a string with eval(s.replace(";", ",")) or similar')
return (v[0], v[1])
# Otherwise, derive the type from the default value.
return (v, type(v))
def autotype(x):
"""Auto-converts string to bool/int/float if possible."""
assert isinstance(x, str)
if x.lower() in {'true', 'false'}:
return x.lower() == 'true' # Returns as bool.
try:
return int(x) # Returns as int.
except ValueError:
try:
return float(x) # Returns as float.
except ValueError:
return x # Returns as str.
def pack_arg(**kw):
"""Packs key-word args as a string to be parsed by `parse_arg()`."""
for v in kw.values():
assert ',' not in f'{v}', f"Can't use `,` in config_arg value: {v}"
return ','.join([f'{k}={v}' for k, v in kw.items()])
def arg(**kw):
"""Use like `add(**bvcc.arg(res=256, foo=bar), lr=0.1)` to pass config_arg."""
return {'config_arg': pack_arg(**kw), **kw}
def _get_field_ref(config_dict, field_name):
path = field_name.split('.')
for field in path[:-1]:
config_dict = getattr(config_dict, field)
return config_dict.get_ref(path[-1])
def format_str(format_string, config):
"""Format string with reference fields from config.
This makes it easy to build preprocess strings that contain references to
fields tha are edited after. E.g.:
```
config = mlc.ConficDict()
config.res = (256, 256)
config.pp = bvcc.format_str('resize({res})', config)
...
# if config.res is modified (e.g. via sweeps) it will propagate to pp field:
config.res = (512, 512)
assert config.pp == 'resize((512, 512))'
```
Args:
format_string: string to format with references.
config: ConfigDict to get references to format the string.
Returns:
A reference field which renders a string using references to config fields.
"""
output = ''
parts = string.Formatter().parse(format_string)
for (literal_text, field_name, format_spec, conversion) in parts:
assert not format_spec and not conversion
output += literal_text
if field_name:
output += _get_field_ref(config, field_name).to_str()
return output