|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Global Registry for big_vision pp ops.
|
|
|
|
Author: Joan Puigcerver (jpuigcerver@)
|
|
"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import ast
|
|
import contextlib
|
|
import functools
|
|
|
|
|
|
def parse_name(string_to_parse):
  """Parses input to the registry's lookup function.

  Args:
    string_to_parse: can be either an arbitrary name or function call
      (optionally with positional and keyword arguments).
      e.g. "multiclass", "resnet50_v2(filters_factor=8)".

  Returns:
    A tuple of input name, argument tuple and a keyword argument dictionary.
    Examples:
      "multiclass" -> ("multiclass", (), {})
      "resnet50_v2(9, filters_factor=4)" ->
          ("resnet50_v2", (9,), {"filters_factor": 4})

  Raises:
    ValueError: if the parsed expression is not a name, dotted name or call;
      if the called object is not a (dotted) name; if an argument is not a
      Python literal; or if `**mapping` keyword unpacking is used.
    SyntaxError: if the string is not a valid Python expression (propagated
      from `ast.parse`).

  Author: Joan Puigcerver (jpuigcerver@)
  """
  expr = ast.parse(string_to_parse, mode="eval").body
  if not isinstance(expr, (ast.Attribute, ast.Call, ast.Name)):
    raise ValueError(
        "The given string should be a name or a call, but a {} was parsed from "
        "the string {!r}".format(type(expr), string_to_parse))

  # A plain or dotted name carries no arguments; the input string itself is
  # returned verbatim as the lookup key.
  if isinstance(expr, (ast.Name, ast.Attribute)):
    return string_to_parse, (), {}

  def _get_func_name(expr):
    """Rebuilds the dotted name of the called object, e.g. "a.b.c"."""
    if isinstance(expr, ast.Attribute):
      return _get_func_name(expr.value) + "." + expr.attr
    elif isinstance(expr, ast.Name):
      return expr.id
    else:
      raise ValueError(
          "Type {!r} is not supported in a function name, the string to parse "
          "was {!r}".format(type(expr), string_to_parse))

  def _get_func_args_and_kwargs(call):
    """Evaluates the call's arguments; each must be a Python literal."""
    # ast.literal_eval raises ValueError for non-literal nodes (including
    # `*args` unpacking), which callers already handle.
    args = tuple(ast.literal_eval(arg) for arg in call.args)
    kwargs = {}
    for kwarg in call.keywords:
      # `f(**mapping)` appears as a keyword whose `arg` is None. Without this
      # check it would be stored under a `None` key and only fail later, with
      # a confusing TypeError, when the kwargs are applied.
      if kwarg.arg is None:
        raise ValueError(
            "Keyword argument unpacking (**) is not supported, the string to "
            "parse was {!r}".format(string_to_parse))
      kwargs[kwarg.arg] = ast.literal_eval(kwarg.value)
    return args, kwargs

  func_name = _get_func_name(expr.func)
  func_args, func_kwargs = _get_func_args_and_kwargs(expr)

  return func_name, func_args, func_kwargs
|
|
|
|
|
|
class Registry(object):
  """Implements global Registry.

  Every method is static and operates on one class-level dict that maps
  names to registered items.

  Authors: Joan Puigcerver (jpuigcerver@), Alexander Kolesnikov (akolesnikov@)
  """

  _GLOBAL_REGISTRY = {}

  @staticmethod
  def global_registry():
    """Returns the dict backing the registry itself (not a copy)."""
    return Registry._GLOBAL_REGISTRY

  @staticmethod
  def register(name, replace=False):
    """Creates a function that registers its input.

    Args:
      name: key under which the decorated item is stored.
      replace: when False, registering a name that is already present raises
        KeyError; when True, the existing entry is overwritten.

    Returns:
      A decorator that stores its argument under `name` and returns the
      argument unchanged.
    """

    def _register(item):
      table = Registry.global_registry()
      clash = name in table
      if clash and not replace:
        raise KeyError("The name {!r} was already registered.".format(name))
      table[name] = item
      return item

    return _register

  @staticmethod
  def _parse(lookup_string):
    """Parses `lookup_string` with `parse_name`, adding context on failure."""
    try:
      return parse_name(lookup_string)
    except ValueError as e:
      raise ValueError(f"Error parsing:\n{lookup_string}") from e

  @staticmethod
  def lookup(lookup_string, kwargs_extra=None):
    """Lookup a name in the registry.

    Args:
      lookup_string: a name or call string understood by `parse_name`.
      kwargs_extra: optional mapping merged over the parsed keyword arguments.

    Returns:
      `functools.partial` of the registered item, bound to the parsed
      positional and keyword arguments.
    """
    name, args, kwargs = Registry._parse(lookup_string)
    if kwargs_extra:
      kwargs.update(kwargs_extra)
    target = Registry.global_registry()[name]
    return functools.partial(target, *args, **kwargs)

  @staticmethod
  def knows(lookup_string):
    """Returns whether the name parsed from `lookup_string` is registered."""
    name, _args, _kwargs = Registry._parse(lookup_string)
    return name in Registry.global_registry()
|
|
|
|
|
|
@contextlib.contextmanager
def temporary_ops(**kw):
  """Registers specified pp ops for use in a `with` block.

  Example use:

    with pp_registry.temporary_ops(
        pow=lambda alpha: lambda d: {k: v**alpha for k, v in d.items()}):
      pp = pp_builder.get_preprocess_fn("pow(alpha=2.0)|pow(alpha=0.5)")
      features = pp(features)

  Args:
    **kw: Names are preprocess string function names to be used to specify the
      preprocess function. Values are functions that can be called with params
      (e.g. the `alpha` param in above example) and return functions to be used
      to transform features. Each name is registered as
      "preprocess_ops.<name>".

  Yields:
    None. Calling this function returns a context manager to be used in a
    `with` statement; the ops are registered only inside that block.

  Raises:
    KeyError: if any of the resulting "preprocess_ops.<name>" keys is already
      registered. Nothing is registered in that case.
  """
  reg = Registry.global_registry()
  kw = {f"preprocess_ops.{k}": v for k, v in kw.items()}
  # Check every name before registering any of them, so a clash leaves the
  # registry untouched. This is an explicit raise rather than `assert`, so the
  # check still runs under `python -O`. Otherwise an existing op would be
  # overwritten here and then deleted on exit.
  for k in kw:
    if k in reg:
      raise KeyError("The name {!r} was already registered.".format(k))
  reg.update(kw)
  try:
    yield
  finally:
    for k in kw:
      # pop() with a default tolerates ops already removed inside the `with`
      # body. A KeyError raised here would mask the body's own exception.
      reg.pop(k, None)
|
|
|