import copy
import inspect
import math
import warnings
from typing import Any, Optional, Union

import torch
import torch.nn as nn
from torch import Tensor

from ..configs.config.config import Config, ConfigDict
from ..utils.manager import ManagerMixin
from ..utils.registry import Registry

WEIGHT_INITIALIZERS = Registry('weight initializer')


@WEIGHT_INITIALIZERS.register_module(name='Pretrained')
class PretrainedInit:
    """Initialize a module by loading a pretrained model.

    Args:
        checkpoint (str): the checkpoint file of the pretrained model that
            should be loaded.
        prefix (str, optional): the prefix of a sub-module in the pretrained
            model. It is used for loading only a part of the pretrained model
            to initialize. For example, if we would like to only load the
            backbone of a detector model, we can set ``prefix='backbone.'``.
            Defaults to None.
        map_location (str): map tensors into proper locations. Defaults to cpu.
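
    Example:
        A minimal sketch of configuring this initializer; the checkpoint
        path below is a placeholder, not a real file.

        >>> init_cfg = dict(type='Pretrained',
        >>>                 checkpoint='path/to/checkpoint.pth',
        >>>                 prefix='backbone.')
        >>> initialize(model, init_cfg)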
    """

    def __init__(self, checkpoint, prefix=None, map_location='cpu'):
        self.checkpoint = checkpoint
        self.prefix = prefix
        self.map_location = map_location

    def __call__(self, module):
        from mmengine.runner.checkpoint import (_load_checkpoint_with_prefix,
                                                load_checkpoint,
                                                load_state_dict)
        if self.prefix is None:
            load_checkpoint(
                module,
                self.checkpoint,
                map_location=self.map_location,
                strict=False,
                logger='current')
        else:
            state_dict = _load_checkpoint_with_prefix(
                self.prefix, self.checkpoint, map_location=self.map_location)
            load_state_dict(module, state_dict, strict=False, logger='current')

        if hasattr(module, '_params_init_info'):
            update_init_info(module, init_info=self._get_init_info())

    def _get_init_info(self):
        info = f'{self.__class__.__name__}: load from {self.checkpoint}'
        return info


def update_init_info(module, init_info):
    """Update the ``_params_init_info`` of the module if the values of its
    parameters have changed.

    Args:
        module (obj:`nn.Module`): The PyTorch module with a user-defined
            attribute ``_params_init_info`` which records the initialization
            information.
        init_info (str): The string that describes the initialization.
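
    Example:
        A hedged sketch of the bookkeeping this helper expects. The exact
        way ``_params_init_info`` is built elsewhere is assumed here; only
        the keys used below appear in this function.

        >>> model = nn.Linear(2, 2)
        >>> model._params_init_info = {
        ...     param: dict(init_info='not initialized',
        ...                 tmp_mean_value=param.data.mean().cpu())
        ...     for param in model.parameters()}
        >>> _ = nn.init.constant_(model.weight, 1.0)
        >>> update_init_info(model, init_info='constant init')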
    """
    assert hasattr(
        module,
        '_params_init_info'), f'Can not find `_params_init_info` in {module}'

    for name, param in module.named_parameters():

        assert param in module._params_init_info, (
            f'Find a new :obj:`Parameter` '
            f'named `{name}` during executing the '
            f'`init_weights` of '
            f'`{module.__class__.__name__}`. '
            f'Please do not add or '
            f'replace parameters during executing '
            f'the `init_weights`. ')

        # The parameter has been changed during running `init_weights`;
        # the mean value is used as a cheap fingerprint to detect the change.
        mean_value = param.data.mean().cpu()
        if module._params_init_info[param]['tmp_mean_value'] != mean_value:
            module._params_init_info[param]['init_info'] = init_info
            module._params_init_info[param]['tmp_mean_value'] = mean_value


def initialize(module, init_cfg):
    r"""Initialize a module.

    Args:
        module (``torch.nn.Module``): the module to be initialized.
        init_cfg (dict | list[dict]): initialization configuration dict to
            define initializer. OpenMMLab has implemented 6 initializers
            including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``,
            ``Kaiming``, and ``Pretrained``.

    Example:
        >>> module = nn.Linear(2, 3, bias=True)
        >>> init_cfg = dict(type='Constant', layer='Linear', val=1, bias=2)
        >>> initialize(module, init_cfg)
        >>> module = nn.Sequential(nn.Conv1d(3, 1, 3), nn.Linear(1, 2))
        >>> # define key ``'layer'`` for initializing layers with different
        >>> # configurations
        >>> init_cfg = [dict(type='Constant', layer='Conv1d', val=1),
                dict(type='Constant', layer='Linear', val=2)]
        >>> initialize(module, init_cfg)
        >>> # define key ``'override'`` to initialize some specific part of
        >>> # the module
        >>> class FooNet(nn.Module):
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.feat = nn.Conv2d(3, 16, 3)
        >>>         self.reg = nn.Conv2d(16, 10, 3)
        >>>         self.cls = nn.Conv2d(16, 5, 3)
        >>> model = FooNet()
        >>> init_cfg = dict(type='Constant', val=1, bias=2, layer='Conv2d',
        >>>     override=dict(type='Constant', name='reg', val=3, bias=4))
        >>> initialize(model, init_cfg)
        >>> model = ResNet(depth=50)
        >>> # Initialize weights with the pretrained model.
        >>> init_cfg = dict(type='Pretrained',
                checkpoint='torchvision://resnet50')
        >>> initialize(model, init_cfg)
        >>> # Initialize weights of a sub-module with the specific part of
        >>> # a pretrained model by using "prefix".
        >>> url = 'http://download.openmmlab.com/mmdetection/v2.0/retinanet/'\
        >>>     'retinanet_r50_fpn_1x_coco/'\
        >>>     'retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth'
        >>> init_cfg = dict(type='Pretrained',
                checkpoint=url, prefix='backbone.')
    """
    if not isinstance(init_cfg, (dict, list)):
        raise TypeError(f'init_cfg must be a dict or a list of dict, \
                but got {type(init_cfg)}')

    if isinstance(init_cfg, dict):
        init_cfg = [init_cfg]

    for cfg in init_cfg:
        # Deep-copy the config because `cfg` may be shared by several
        # modules (e.g. one init_cfg reused by multiple blocks); popping
        # keys from the original would change the behavior elsewhere.
        cp_cfg = copy.deepcopy(cfg)
        override = cp_cfg.pop('override', None)
        _initialize(module, cp_cfg)

        if override is not None:
            cp_cfg.pop('layer', None)
            _initialize_override(module, override, cp_cfg)
        else:
            # All attributes in the module share the same initialization.
            pass


def _initialize(module, cfg, wholemodule=False):
    func = build_from_cfg(cfg, WEIGHT_INITIALIZERS)
    # The `wholemodule` flag is for the override mode: an override config has
    # no `layer` key, and the initializer is applied to the whole sub-module
    # referred to by `name` in the override.
    func.wholemodule = wholemodule
    func(module)


def _initialize_override(module, override, cfg):
    if not isinstance(override, (dict, list)):
        raise TypeError(f'override must be a dict or a list of dict, \
                but got {type(override)}')

    override = [override] if isinstance(override, dict) else override

    for override_ in override:

        cp_override = copy.deepcopy(override_)
        name = cp_override.pop('name', None)
        if name is None:
            raise ValueError('`override` must contain the key "name", '
                             f'but got {cp_override}')

        if not cp_override:
            # If the override only contains the key "name", reuse the
            # arguments of the outer init_cfg.
            cp_override.update(cfg)
        elif 'type' not in cp_override.keys():
            raise ValueError(
                f'`override` need "type" key, but got {cp_override}')

        if hasattr(module, name):
            _initialize(getattr(module, name), cp_override, wholemodule=True)
        else:
            raise RuntimeError(f'module did not have attribute {name}, '
                               f'but init_cfg is {cp_override}.')


def build_from_cfg(
        cfg: Union[dict, ConfigDict, Config],
        registry: Registry,
        default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> Any:
    """Build a module from a config dict when it is a class configuration, or
    call a function from a config dict when it is a function configuration.

    If the global default scope (:obj:`DefaultScope`) exists, :meth:`build`
    will first get the corresponding registry and then call its own
    :meth:`build`.

    At least one of ``cfg`` and ``default_args`` must contain the key "type",
    which should be either a str or a class. If both contain it, the key in
    ``cfg`` is used because ``cfg`` has a higher priority than
    ``default_args``: if a key exists in both of them, the value of the key
    will be ``cfg[key]``. They are merged first, then the key "type" is
    popped and the remaining keys are used as initialization arguments.

    Examples:
        >>> from mmengine import Registry, build_from_cfg
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     def __init__(self, depth, stages=4):
        >>>         self.depth = depth
        >>>         self.stages = stages
        >>> cfg = dict(type='ResNet', depth=50)
        >>> model = build_from_cfg(cfg, MODELS)
        >>> # Returns an instantiated object
        >>> @MODELS.register_module()
        >>> def resnet50():
        >>>     pass
        >>> resnet = build_from_cfg(dict(type='resnet50'), MODELS)
        >>> # Returns the result of calling the function

    Args:
        cfg (dict or ConfigDict or Config): Config dict. It should at least
            contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict or ConfigDict or Config, optional): Default
            initialization arguments. Defaults to None.

    Returns:
        object: The constructed object.
    """
    if not isinstance(cfg, (dict, ConfigDict, Config)):
        raise TypeError(
            f'cfg should be a dict, ConfigDict or Config, but got {type(cfg)}')

    if 'type' not in cfg:
        if default_args is None or 'type' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f'but got {cfg}\n{default_args}')

    if not isinstance(registry, Registry):
        raise TypeError('registry must be a mmengine.Registry object, '
                        f'but got {type(registry)}')

    if not (isinstance(default_args,
                       (dict, ConfigDict, Config)) or default_args is None):
        raise TypeError(
            'default_args should be a dict, ConfigDict, Config or None, '
            f'but got {type(default_args)}')

    args = cfg.copy()
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    # Build the object under the scope specified in the config (if any), so
    # that `type` is looked up in the right registry.
    scope = args.pop('_scope_', None)
    with registry.switch_scope_and_registry(scope) as registry:
        obj_type = args.pop('type')
        if isinstance(obj_type, str):
            obj_cls = registry.get(obj_type)
            if obj_cls is None:
                raise KeyError(
                    f'{obj_type} is not in the {registry.scope}::{registry.name} registry. '
                    f'Please check whether the value of `{obj_type}` is '
                    'correct or it was registered as expected. More details '
                    'can be found at '
                    'https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module'
                )
        elif callable(obj_type):
            obj_cls = obj_type
        else:
            raise TypeError(
                f'type must be a str or valid type, but got {type(obj_type)}')

        # `ManagerMixin` subclasses are created through `get_instance` so
        # that the created instance can be accessed globally by its name.
        if inspect.isclass(obj_cls) and \
                issubclass(obj_cls, ManagerMixin):
            obj = obj_cls.get_instance(**args)
        else:
            obj = obj_cls(**args)
        return obj


def constant_init(module, val, bias=0):
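    """Initialize ``module.weight`` (if present) to the constant ``val`` and
    ``module.bias`` (if present) to ``bias``.

    A minimal usage sketch:

    >>> conv = nn.Conv2d(3, 16, 3)
    >>> constant_init(conv, val=1, bias=2)
    """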
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.constant_(module.weight, val)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def normal_init(module, mean=0, std=1, bias=0):
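    """Initialize ``module.weight`` (if present) from a normal distribution
    ``N(mean, std**2)`` and ``module.bias`` (if present) to ``bias``.

    A minimal usage sketch:

    >>> fc = nn.Linear(8, 4)
    >>> normal_init(fc, mean=0, std=0.01)
    """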
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.normal_(module.weight, mean, std)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def trunc_normal_init(module: nn.Module,
                      mean: float = 0,
                      std: float = 1,
                      a: float = -2,
                      b: float = 2,
                      bias: float = 0) -> None:
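    """Initialize ``module.weight`` (if present) from a truncated normal
    distribution (see :func:`trunc_normal_`) and ``module.bias`` (if present)
    to ``bias``.

    A minimal usage sketch:

    >>> fc = nn.Linear(8, 4)
    >>> trunc_normal_init(fc, std=0.02)
    """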
    if hasattr(module, 'weight') and module.weight is not None:
        trunc_normal_(module.weight, mean, std, a, b)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def kaiming_init(module,
                 a=0,
                 mode='fan_out',
                 nonlinearity='relu',
                 bias=0,
                 distribution='normal'):
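    """Initialize ``module.weight`` (if present) with Kaiming initialization
    (uniform or normal, selected by ``distribution``) and ``module.bias``
    (if present) to ``bias``.

    A minimal usage sketch:

    >>> conv = nn.Conv2d(3, 16, 3)
    >>> kaiming_init(conv, mode='fan_in', distribution='uniform')
    """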
|
assert distribution in ['uniform', 'normal'] |
|
if hasattr(module, 'weight') and module.weight is not None: |
|
if distribution == 'uniform': |
|
nn.init.kaiming_uniform_( |
|
module.weight, a=a, mode=mode, nonlinearity=nonlinearity) |
|
else: |
|
nn.init.kaiming_normal_( |
|
module.weight, a=a, mode=mode, nonlinearity=nonlinearity) |
|
if hasattr(module, 'bias') and module.bias is not None: |
|
nn.init.constant_(module.bias, bias) |
|
|
|
|
|
def trunc_normal_(tensor: Tensor,
                  mean: float = 0.,
                  std: float = 1.,
                  a: float = -2.,
                  b: float = 2.) -> Tensor:
    r"""Fills the input Tensor with values drawn from a truncated normal
    distribution. The values are effectively drawn from the normal
    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values
    outside :math:`[a, b]` redrawn until they are within the bounds. The
    method used for generating the random values works best when
    :math:`a \leq \text{mean} \leq b`.

    Modified from
    https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py

    Args:
        tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`.
        mean (float): the mean of the normal distribution.
        std (float): the standard deviation of the normal distribution.
        a (float): the minimum cutoff value.
        b (float): the maximum cutoff value.
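
    Example:
        A minimal usage sketch:

        >>> w = torch.empty(3, 5)
        >>> _ = trunc_normal_(w, std=0.02)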
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float,
                           b: float) -> Tensor:

    def norm_cdf(x):
        # Computes the standard normal cumulative distribution function.
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
            'The distribution of values may be incorrect.',
            stacklevel=2)

    with torch.no_grad():
        # Values are generated by sampling uniformly between the CDF values
        # of the bounds and then applying the inverse CDF of the normal
        # distribution. First, get the CDF values of the lower and upper
        # bounds.
        lower = norm_cdf((a - mean) / std)
        upper = norm_cdf((b - mean) / std)

        # Uniformly fill the tensor with values in [2 * lower - 1,
        # 2 * upper - 1], the range that erfinv expects for [lower, upper].
        tensor.uniform_(2 * lower - 1, 2 * upper - 1)

        # Use the inverse error function to obtain a truncated standard
        # normal sample.
        tensor.erfinv_()

        # Transform to the requested mean and standard deviation.
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure the values lie within [a, b].
        tensor.clamp_(min=a, max=b)
        return tensor