|
|
|
import inspect |
|
import sys |
|
import types |
|
from collections import abc |
|
from collections.abc import Callable |
|
from contextlib import contextmanager |
|
from importlib import import_module |
|
from typing import Any, Dict, Generator, List, Optional, Tuple, Type, Union |
|
from rich.console import Console |
|
from rich.table import Table |
|
|
|
from .default_scope import DefaultScope |
|
|
|
|
|
# Maps the top-level module name of a known downstream project to the name
# of the distribution that provides it. The registry uses this table to
# recognise those scopes and to build the `pip install ...` hint it prints
# when such a project cannot be imported.
MODULE2PACKAGE = dict(
    mmcls='mmcls',
    mmdet='mmdet',
    mmdet3d='mmdet3d',
    mmseg='mmsegmentation',
    mmaction='mmaction2',
    mmtrack='mmtrack',
    mmpose='mmpose',
    mmedit='mmedit',
    mmocr='mmocr',
    mmgen='mmgen',
    mmfewshot='mmfewshot',
    mmrazor='mmrazor',
    mmflow='mmflow',
    mmhuman3d='mmhuman3d',
    mmrotate='mmrotate',
    mmselfsup='mmselfsup',
    mmyolo='mmyolo',
    mmpretrain='mmpretrain',
    mmagic='mmagic',
)
|
|
|
class Registry:
    """A registry to map strings to classes or functions.

    Registered object could be built from registry. Meanwhile, registered
    functions could be called from registry.

    Args:
        name (str): Registry name.
        build_func (callable, optional): A function to construct instance
            from Registry. :func:`build_from_cfg` is used if neither
            ``parent`` or ``build_func`` is specified. If ``parent`` is
            specified and ``build_func`` is not given, ``build_func`` will
            be inherited from ``parent``. Defaults to None.
        parent (:obj:`Registry`, optional): Parent registry. The class
            registered in children registry could be built from parent.
            Defaults to None.
        scope (str, optional): The scope of registry. It is the key to search
            for children registry. If not specified, scope will be the name
            of the package where class is defined, e.g. mmdet, mmcls, mmseg.
            Defaults to None.
        locations (list, optional): The locations to import the modules
            registered in this registry. The given list is copied, so later
            changes to it do not affect the registry. Defaults to None,
            which means no locations.
            New in version 0.4.0.
    """

    def __init__(self,
                 name: str,
                 build_func: Optional[Callable] = None,
                 parent: Optional['Registry'] = None,
                 scope: Optional[str] = None,
                 locations: Optional[List] = None):
        self._name = name
        self._module_dict: Dict[str, Type] = dict()
        self._children: Dict[str, 'Registry'] = dict()
        # Always keep a private copy: a shared default list (or a list owned
        # by the caller) must never be aliased between registries.
        self._locations: List = [] if locations is None else list(locations)
        self._imported = False

        if scope is not None:
            assert isinstance(scope, str)
            self._scope = scope
        else:
            # Must be called directly from here: ``infer_scope`` inspects a
            # fixed number of stack frames above itself.
            self._scope = self.infer_scope()

        self.parent: Optional['Registry']
        if parent is not None:
            assert isinstance(parent, Registry)
            # Attach to the parent first so a duplicated scope is rejected
            # before this registry records its parent.
            parent._add_child(self)
            self.parent = parent
        else:
            self.parent = None

        self.build_func: Callable
        if build_func is not None:
            self.build_func = build_func
        elif self.parent is not None:
            self.build_func = self.parent.build_func
        else:
            # Imported lazily, only when the default builder is needed.
            from ..utils.build_functions import build_from_cfg
            self.build_func = build_from_cfg

    def __len__(self):
        """Return the number of names registered in this registry."""
        return len(self._module_dict)

    def __contains__(self, key):
        """Return whether :meth:`get` can resolve ``key``."""
        return self.get(key) is not None

    def __repr__(self):
        """Render the registered names and objects as a rich table."""
        table = Table(title=f'Registry of {self._name}')
        table.add_column('Names', justify='left', style='cyan')
        table.add_column('Objects', justify='left', style='green')

        for name, obj in sorted(self._module_dict.items()):
            table.add_row(name, str(obj))

        # Capture the console output instead of printing it, so the rendered
        # table can be returned as the representation string.
        console = Console()
        with console.capture() as capture:
            console.print(table, end='')

        return capture.get()

    @staticmethod
    def infer_scope() -> str:
        """Infer the scope of registry.

        The name of the package where registry is defined will be returned.

        Returns:
            str: The inferred scope name. ``'mmengine'`` is returned when
            the defining module cannot be determined.
        """
        # Frame 0 is this function and frame 1 is its caller (``__init__``),
        # so frame 2 is the code that instantiated the registry.
        module = inspect.getmodule(sys._getframe(2))
        if module is None:
            return 'mmengine'
        # The scope is the top-level package of the defining module.
        return module.__name__.split('.')[0]

    @staticmethod
    def split_scope_key(key: str) -> Tuple[Optional[str], str]:
        """Split scope and key.

        The first scope will be split from key.

        Return:
            tuple[str | None, str]: The former element is the first scope of
            the key, which can be ``None``. The latter is the remaining key.
        """
        split_index = key.find('.')
        if split_index == -1:
            return None, key
        return key[:split_index], key[split_index + 1:]

    @property
    def name(self):
        return self._name

    @property
    def scope(self):
        return self._scope

    @property
    def module_dict(self):
        return self._module_dict

    @property
    def children(self):
        return self._children

    @property
    def root(self):
        return self._get_root_registry()

    @contextmanager
    def switch_scope_and_registry(self, scope: Optional[str]) -> Generator:
        """Temporarily switch default scope to the target scope, and get the
        corresponding registry.

        If the registry of the corresponding scope exists, yield the
        registry, otherwise yield the current itself.

        Args:
            scope (str, optional): The target scope.
        """
        with DefaultScope.overwrite_default_scope(scope):
            default_scope = DefaultScope.get_current_instance()

            if default_scope is None:
                registry = self
            else:
                # Report problems with the scope that is actually imported
                # and searched; the ``scope`` argument itself may be None.
                scope_name = default_scope.scope_name
                try:
                    import_module(f'{scope_name}.registry')
                except (ImportError, AttributeError, ModuleNotFoundError):
                    if scope_name in MODULE2PACKAGE:
                        print(
                            f'{scope_name} is not installed and its '
                            'modules will not be registered. If you '
                            'want to use modules defined in '
                            f'{scope_name}, Please install {scope_name} by '
                            f'`pip install {MODULE2PACKAGE[scope_name]}`.')
                    else:
                        print(
                            f'Failed to import `{scope_name}.registry` '
                            'make sure the registry.py exists in '
                            f'`{scope_name}` package.')

                root = self._get_root_registry()
                registry = root._search_child(scope_name)
                if registry is None:
                    print(
                        f'Failed to search registry with scope "{scope_name}" '
                        f'in the "{root.name}" registry tree. '
                        f'As a workaround, the current "{self.name}" registry '
                        f'in "{self.scope}" is used to build instance. This '
                        'may cause unexpected failure when running the built '
                        f'modules. Please check whether "{scope_name}" is a '
                        'correct scope, or whether the registry is '
                        'initialized.')
                    registry = self
            yield registry

    def _get_root_registry(self) -> 'Registry':
        """Return the root registry."""
        root = self
        while root.parent is not None:
            root = root.parent
        return root

    def import_from_location(self) -> None:
        """Import modules from the pre-defined locations in
        ``self._locations``.

        The import is attempted only once per registry. When no location was
        configured and the scope is listed in ``MODULE2PACKAGE``,
        ``{scope}.utils.register_all_modules(False)`` is called instead.
        """
        if self._imported:
            return

        if not self._locations and self.scope in MODULE2PACKAGE:
            print(
                f'The "{self.name}" registry in {self.scope} did not '
                'set import location. Fallback to call '
                f'`{self.scope}.utils.register_all_modules` '
                'instead.')
            try:
                module = import_module(f'{self.scope}.utils')
            except (ImportError, AttributeError, ModuleNotFoundError):
                # The scope is known to be in MODULE2PACKAGE here, so the
                # install hint is the only applicable message.
                print(
                    f'{self.scope} is not installed and its '
                    'modules will not be registered. If you '
                    'want to use modules defined in '
                    f'{self.scope}, Please install {self.scope} by '
                    f'`pip install {MODULE2PACKAGE[self.scope]}`.')
            else:
                # Errors raised while registering are deliberately not
                # caught; only a failed import of the package is tolerated.
                module.register_all_modules(False)

        for loc in self._locations:
            import_module(loc)
            print(
                f"Modules of {self.scope}'s {self.name} registry have "
                f'been automatically imported from {loc}')
        self._imported = True

    def get(self, key: str) -> Optional[Type]:
        """Get the registry record.

        If ``key`` represents the whole object name with its module
        information, for example, `mmengine.model.BaseModel`, ``get``
        will directly return the class object :class:`BaseModel`.

        Otherwise, it will first parse ``key`` and check whether it
        contains a scope name. The logic to search for ``key``:

        - ``key`` does not contain a scope name, i.e., it is purely a module
          name like "ResNet": :meth:`get` will search for ``ResNet`` from the
          current registry to its parent or ancestors until finding it.

        - ``key`` contains a scope name and it is equal to the scope of the
          current registry (e.g., "mmcls"), e.g., "mmcls.ResNet": :meth:`get`
          will only search for ``ResNet`` in the current registry.

        - ``key`` contains a scope name and it is not equal to the scope of
          the current registry (e.g., "mmdet"), e.g., "mmcls.FCNet": If the
          scope exists in its children, :meth:`get` will get "FCNet" from
          them. If not, :meth:`get` will first get the root registry and root
          registry call its own :meth:`get` method.

        Args:
            key (str): Name of the registered item, e.g., the class name in
                string format.

        Returns:
            Type or None: Return the corresponding class if ``key`` exists,
            otherwise return None.

        Raises:
            TypeError: If ``key`` is not a str.
            RuntimeError: If resolving ``key`` as a dotted object path
                raises an exception.
        """
        if not isinstance(key, str):
            raise TypeError(
                'The key argument of `Registry.get` must be a str, '
                f'got {type(key)}')

        scope, real_key = self.split_scope_key(key)
        obj_cls = None

        # Import the configured locations first so their modules are
        # registered before the lookup.
        self.import_from_location()

        if scope is None or scope == self._scope:
            if real_key in self._module_dict:
                obj_cls = self._module_dict[real_key]
            elif scope is None:
                # An unscoped key may also live in any ancestor registry.
                ancestor = self.parent
                while ancestor is not None:
                    if real_key in ancestor._module_dict:
                        obj_cls = ancestor._module_dict[real_key]
                        break
                    ancestor = ancestor.parent
        else:
            # Importing the scope's registry module is expected to attach
            # its registries to the tree searched below.
            try:
                import_module(f'{scope}.registry')
            except (ImportError, AttributeError, ModuleNotFoundError):
                print(
                    f'Cannot auto import {scope}.registry, please check '
                    f'whether the package "{scope}" is installed correctly '
                    'or import the registry manually.')
            else:
                print(
                    f'Registry node of {scope} has been automatically '
                    'imported.')

            if scope in self._children:
                obj_cls = self._children[scope].get(real_key)
            else:
                root = self._get_root_registry()
                # Only delegate when the root can resolve the scope itself;
                # otherwise ``root.get(key)`` would reach this same branch
                # again and recurse without end.
                if scope == root._scope or scope in root._children:
                    obj_cls = root.get(key)

        if obj_cls is None:
            # Last resort: treat ``key`` as a dotted import path such as
            # "package.module.Class".
            try:
                obj_cls = get_object_from_string(key)
            except Exception as err:
                raise RuntimeError(f'Failed to get {key}') from err

        return obj_cls

    def _search_child(self, scope: str) -> Optional['Registry']:
        """Depth-first search for the corresponding registry in its children.

        Note that the method only search for the corresponding registry from
        the current registry. Therefore, if we want to search from the root
        registry, :meth:`_get_root_registry` should be called to get the
        root registry first.

        Args:
            scope (str): The scope name used for searching for its
                corresponding registry.

        Returns:
            Registry or None: Return the corresponding registry if ``scope``
            exists, otherwise return None.
        """
        if self._scope == scope:
            return self

        for child in self._children.values():
            found = child._search_child(scope)
            if found is not None:
                return found

        return None

    def build(self, cfg: dict, *args, **kwargs) -> Any:
        """Build an instance.

        Build an instance by calling :attr:`build_func`.

        Args:
            cfg (dict): Config dict needs to be built.

        Returns:
            Any: The constructed object.
        """
        return self.build_func(cfg, *args, **kwargs, registry=self)

    def _add_child(self, registry: 'Registry') -> None:
        """Add a child for a registry.

        Args:
            registry (:obj:`Registry`): The ``registry`` will be added as a
                child of the ``self``.
        """
        assert isinstance(registry, Registry)
        assert registry.scope is not None
        assert registry.scope not in self.children, \
            f'scope {registry.scope} exists in {self.name} registry'
        self.children[registry.scope] = registry

    def _register_module(self,
                         module: Type,
                         module_name: Optional[Union[str, List[str]]] = None,
                         force: bool = False) -> None:
        """Register a module.

        Args:
            module (type): Module to be registered. Typically a class or a
                function, but generally all ``Callable`` are acceptable.
            module_name (str or list of str, optional): The module name to be
                registered. If not specified, the class name will be used.
                Defaults to None.
            force (bool): Whether to override an existing class with the same
                name. Defaults to False.

        Raises:
            TypeError: If ``module`` is not callable.
            KeyError: If a name is already registered and ``force`` is False.
        """
        if not callable(module):
            raise TypeError(f'module must be Callable, but got {type(module)}')

        if module_name is None:
            module_name = module.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if not force and name in self._module_dict:
                existed_module = self.module_dict[name]
                raise KeyError(f'{name} is already registered in {self.name} '
                               f'at {existed_module.__module__}')
            self._module_dict[name] = module

    def register_module(
            self,
            name: Optional[Union[str, List[str]]] = None,
            force: bool = False,
            module: Optional[Type] = None) -> Union[type, Callable]:
        """Register a module.

        A record will be added to ``self._module_dict``, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.

        Args:
            name (str or list of str, optional): The module name to be
                registered. If not specified, the class name will be used.
            force (bool): Whether to override an existing class with the same
                name. Defaults to False.
            module (type, optional): Module class or function to be registered.
                Defaults to None.

        Returns:
            type or Callable: ``module`` itself when it is given, otherwise a
            decorator that registers the decorated object and returns it.

        Raises:
            TypeError: If ``force`` is not a bool, or ``name`` is not None,
                a str, or a sequence of str.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')

        if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
            raise TypeError(
                'name must be None, an instance of str, or a sequence of str, '
                f'but got {type(name)}')

        # Used as a normal function: x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(module=module, module_name=name, force=force)
            return module

        # Used as a decorator: @x.register_module()
        def _register(module):
            self._register_module(module=module, module_name=name, force=force)
            return module

        return _register
|
|
|
|
|
def is_seq_of(seq: Any,
              expected_type: Union[Type, tuple],
              seq_type: Optional[Type] = None) -> bool:
    """Tell whether ``seq`` is a sequence whose items all have a given type.

    Args:
        seq (Sequence): The object to inspect.
        expected_type (type or tuple): Type (or tuple of types) every item
            must be an instance of.
        seq_type (type, optional): Required container type. When omitted,
            any ``collections.abc.Sequence`` is accepted. Defaults to None.

    Returns:
        bool: True when ``seq`` is an instance of the required container
        type and every item is an instance of ``expected_type``; False
        otherwise. An empty container of the right type yields True.
    """
    if seq_type is not None:
        assert isinstance(seq_type, type)
        container_type = seq_type
    else:
        container_type = abc.Sequence

    # The container check short-circuits, so non-sequences are never iterated.
    return isinstance(seq, container_type) and all(
        isinstance(element, expected_type) for element in seq)
|
|
|
|
|
def get_object_from_string(obj_name: str):
    """Resolve a dotted name such as ``"package.module.attr"`` to an object.

    The leading components are imported as modules for as long as possible;
    the remaining components are then looked up as attributes.

    Args:
        obj_name (str): The dotted name of the object.

    Returns:
        The module or attribute the name refers to, or None when a module
        cannot be imported or an attribute does not exist.
    """
    names = obj_name.split('.')
    module_path = names[0]
    position = 1

    # Phase 1: extend the module path one component at a time.
    while True:
        try:
            module = import_module(module_path)
            if position == len(names):
                # Every component was a module.
                return module
            attr_name = names[position]
            position += 1

            # A non-module attribute wins over a submodule of the same
            # name, so stop importing as soon as one is found.
            candidate = getattr(module, attr_name, None)
            if candidate is not None and not ismodule(candidate):
                break
            module_path = f'{module_path}.{attr_name}'
        except ImportError:
            return None

    # Phase 2: walk the remaining components as plain attributes.
    target = module
    try:
        for name in [attr_name, *names[position:]]:
            target = getattr(target, name)
    except AttributeError:
        return None
    return target
|
|
|
def ismodule(object):
    """Return True if ``object`` is a module, False otherwise.

    Module objects provide these attributes:
        __cached__      pathname to byte compiled file
        __doc__         documentation string
        __file__        filename (missing for built-in modules)
    """
    module_type = types.ModuleType
    return isinstance(object, module_type)