import copy
import functools
import warnings
from abc import ABCMeta
from collections import defaultdict
from inspect import getfullargspec
from typing import Callable, Iterable, List, Optional, Union

import torch.nn as nn

from .utils import is_model_wrapper
from .weight_init import PretrainedInit, initialize, update_init_info
from ..utils.activation import build_dropout
from ..utils.registry import MODELS


class BaseModule(nn.Module, metaclass=ABCMeta):
    """Base module for all modules in openmmlab. ``BaseModule`` is a wrapper
    of ``torch.nn.Module`` with additional functionality for parameter
    initialization. Compared with ``torch.nn.Module``, ``BaseModule`` mainly
    adds three attributes.

    - ``init_cfg``: the config to control the initialization.
    - ``init_weights``: the function of parameter initialization and
      recording initialization information.
    - ``_params_init_info``: used to track the parameter initialization
      information. This attribute only exists while ``init_weights`` is
      being executed.

    Note:
        :obj:`PretrainedInit` has a higher priority than any other
        initializer. The loaded pretrained weights will overwrite the
        previously initialized weights.

    Args:
        init_cfg (dict or List[dict], optional): Initialization config dict.
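
    Example:
        A minimal, illustrative sketch; ``ToyNet`` and the init config are
        hypothetical, not part of this module:

        >>> class ToyNet(BaseModule):
        ...     def __init__(self, init_cfg=None):
        ...         super().__init__(init_cfg)
        ...         self.conv = nn.Conv2d(3, 16, 3)
        >>> net = ToyNet(init_cfg=dict(type='Kaiming', layer='Conv2d'))
        >>> net.init_weights()  # applies the Kaiming init to ``conv``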
    """

    def __init__(self, init_cfg: Union[dict, List[dict], None] = None):
        """Initialize BaseModule, inherited from `torch.nn.Module`."""
        super().__init__()
        # NOTE: `init_cfg` can be defined at different levels of a model,
        # and an `init_cfg` defined at a lower level has a higher priority.
        self._is_init = False
        self.init_cfg = copy.deepcopy(init_cfg)

    @property
    def is_init(self):
        return self._is_init

    @is_init.setter
    def is_init(self, value):
        self._is_init = value

    def init_weights(self):
        """Initialize the weights."""

        is_top_level_module = False
        # Check whether this is the top-level module in the call chain.
        if not hasattr(self, '_params_init_info'):
            # `_params_init_info` records the initialization information
            # of every parameter:
            # - init_info (str): a string describing the initialization.
            # - tmp_mean_value (FloatTensor): the mean of the parameter,
            #   used to detect whether the parameter has been modified.
            # The attribute is deleted once all parameters are initialized.
            self._params_init_info = defaultdict(dict)
            is_top_level_module = True

            # Initialize `_params_init_info`; when the `tmp_mean_value` of
            # a parameter changes, its initialization information is
            # updated accordingly.
            for name, param in self.named_parameters():
                self._params_init_info[param][
                    'init_info'] = f'The value is the same before and ' \
                                   f'after calling `init_weights` ' \
                                   f'of {self.__class__.__name__} '
                self._params_init_info[param][
                    'tmp_mean_value'] = param.data.mean().cpu()

            # Share `_params_init_info` with all submodules so that it is
            # updated when parameters are modified at any level of the
            # model.
            for sub_module in self.modules():
                sub_module._params_init_info = self._params_init_info

        if not self._is_init:
            if self.init_cfg:
                init_cfgs = self.init_cfg
                if isinstance(self.init_cfg, dict):
                    init_cfgs = [self.init_cfg]

                # `PretrainedInit` has a higher priority than any other
                # initializer, so pretrained configs are applied last to
                # overwrite the previously initialized weights.
                other_cfgs = []
                pretrained_cfg = []
                for init_cfg in init_cfgs:
                    assert isinstance(init_cfg, dict)
                    if (init_cfg['type'] == 'Pretrained'
                            or init_cfg['type'] is PretrainedInit):
                        pretrained_cfg.append(init_cfg)
                    else:
                        other_cfgs.append(init_cfg)

                initialize(self, other_cfgs)

            for m in self.children():
                if is_model_wrapper(m) and not hasattr(m, 'init_weights'):
                    m = m.module
                if hasattr(m, 'init_weights') and not getattr(
                        m, 'is_init', False):
                    m.init_weights()
                    # Users may have overloaded `init_weights`.
                    update_init_info(
                        m,
                        init_info=f'Initialized by '
                        f'user-defined `init_weights`'
                        f' in {m.__class__.__name__} ')
            if self.init_cfg and pretrained_cfg:
                initialize(self, pretrained_cfg)
            self._is_init = True

        if is_top_level_module:
            self._dump_init_info()

            for sub_module in self.modules():
                del sub_module._params_init_info

    def __repr__(self):
        s = super().__repr__()
        if self.init_cfg:
            s += f'\ninit_cfg={self.init_cfg}'
        return s


def deprecated_api_warning(name_dict: dict,
                           cls_name: Optional[str] = None) -> Callable:
    """A decorator to check whether some arguments are deprecated and, if so,
    replace the deprecated `src_arg_name` with the expected `dst_arg_name`.

    Args:
        name_dict (dict):
            key (str): Deprecated argument name.
            val (str): Expected argument name.
        cls_name (str, optional): Class name prepended to the function name
            in warning messages. Defaults to None.

    Returns:
        func: New function.
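
    Example:
        A minimal, illustrative sketch; ``old_name`` and ``new_name`` are
        hypothetical argument names:

        >>> @deprecated_api_warning({'old_name': 'new_name'})
        ... def func(new_name=1):
        ...     return new_name
        >>> func(old_name=2)  # warns, then forwards the value to `new_name`
        2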
    """

    def api_warning_wrapper(old_func):

        @functools.wraps(old_func)
        def new_func(*args, **kwargs):
            # Get the argument specification of the decorated function.
            args_info = getfullargspec(old_func)
            # Get the name of the decorated function, prefixed with the
            # class name if one was given.
            func_name = old_func.__name__
            if cls_name is not None:
                func_name = f'{cls_name}.{func_name}'
            if args:
                arg_names = args_info.args[:len(args)]
                for src_arg_name, dst_arg_name in name_dict.items():
                    if src_arg_name in arg_names:
                        warnings.warn(
                            f'"{src_arg_name}" is deprecated in '
                            f'`{func_name}`, please use "{dst_arg_name}" '
                            'instead', DeprecationWarning)
                        arg_names[arg_names.index(src_arg_name)] = dst_arg_name
            if kwargs:
                for src_arg_name, dst_arg_name in name_dict.items():
                    if src_arg_name in kwargs:
                        assert dst_arg_name not in kwargs, (
                            f'The expected behavior is to replace '
                            f'the deprecated key `{src_arg_name}` with the '
                            f'new key `{dst_arg_name}`, but got them '
                            f'in the arguments at the same time, which '
                            f'is confusing. `{src_arg_name}` will be '
                            f'deprecated in the future, please '
                            f'use `{dst_arg_name}` instead.')
                        warnings.warn(
                            f'"{src_arg_name}" is deprecated in '
                            f'`{func_name}`, please use "{dst_arg_name}" '
                            'instead', DeprecationWarning)
                        kwargs[dst_arg_name] = kwargs.pop(src_arg_name)

            # Apply the (possibly renamed) arguments to the decorated
            # function.
            output = old_func(*args, **kwargs)
            return output

        return new_func

    return api_warning_wrapper


@MODELS.register_module()
class MultiheadAttention(BaseModule):
    """A wrapper for ``torch.nn.MultiheadAttention``.

    This module implements MultiheadAttention with identity connection,
    and positional encoding is also passed as input.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
            Default: 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        batch_first (bool): When it is True, Key, Query and Value have shape
            (batch, n, embed_dim); otherwise (n, batch, embed_dim).
            Default: False.
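
    Example:
        A minimal, illustrative sketch; the shapes are arbitrary:

        >>> import torch
        >>> attn = MultiheadAttention(embed_dims=256, num_heads=8)
        >>> x = torch.rand(4, 2, 256)  # [num_queries, bs, embed_dims]
        >>> out = attn(query=x)  # self-attention with identity shortcut
        >>> out.shape
        torch.Size([4, 2, 256])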
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=dict(type='Dropout', drop_prob=0.),
                 init_cfg=None,
                 batch_first=False,
                 **kwargs):
        super().__init__(init_cfg)
        if 'dropout' in kwargs:
            warnings.warn(
                'The argument `dropout` in MultiheadAttention '
                'has been deprecated; now you can separately '
                'set `attn_drop` (float), `proj_drop` (float), '
                'and `dropout_layer` (dict)', DeprecationWarning)
            attn_drop = kwargs['dropout']
            dropout_layer['drop_prob'] = kwargs.pop('dropout')

        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.batch_first = batch_first

        self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
                                          **kwargs)

        self.proj_drop = nn.Dropout(proj_drop)
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else nn.Identity()

    @deprecated_api_warning({'residual': 'identity'},
                            cls_name='MultiheadAttention')
    def forward(self,
                query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_pos=None,
                attn_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `MultiheadAttention`.

        **kwargs allow passing a more general data flow when combining
        with other operations in `transformerlayer`.

        Args:
            query (Tensor): The input query with shape [num_queries, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_queries, embed_dims].
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_keys, embed_dims].
                If None, the ``query`` will be used. Defaults to None.
            value (Tensor): The value tensor with the same shape as `key`.
                Same in `nn.MultiheadAttention.forward`.
                If None, the `key` will be used. Defaults to None.
            identity (Tensor): This tensor, with the same shape as `query`,
                will be used for the identity link.
                If None, `query` will be used. Defaults to None.
            query_pos (Tensor): The positional encoding for `query`, with
                the same shape as `query`. If not None, it will be added
                to `query` before the forward function. Defaults to None.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`. If not None, it will be added to
                `key` before the forward function. If None, and `query_pos`
                has the same shape as `key`, then `query_pos` will be used
                for `key_pos`. Defaults to None.
            attn_mask (Tensor): ByteTensor mask with shape [num_queries,
                num_keys]. Same in `nn.MultiheadAttention.forward`.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
                Defaults to None.

        Returns:
            Tensor: forwarded results with shape
            [num_queries, bs, embed_dims]
            if self.batch_first is False, else
            [bs, num_queries, embed_dims].
        """

        if key is None:
            key = query
        if value is None:
            value = key
        if identity is None:
            identity = query
        if key_pos is None:
            if query_pos is not None:
                # use query_pos if key_pos is not available
                if query_pos.shape == key.shape:
                    key_pos = query_pos
        if query_pos is not None:
            query = query + query_pos
        if key_pos is not None:
            key = key + key_pos

        # Because the dataflow ('key', 'query', 'value') of
        # ``torch.nn.MultiheadAttention.forward`` is (num_queries, batch,
        # embed_dims), adjust the dataflow from batch_first
        # (batch, num_queries, embed_dims) to (num_queries, batch,
        # embed_dims), and recover ``attn_output`` back to batch_first
        # afterwards.
        if self.batch_first:
            query = query.transpose(0, 1)
            key = key.transpose(0, 1)
            value = value.transpose(0, 1)

        out = self.attn(
            query=query,
            key=key,
            value=value,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask)[0]

        if self.batch_first:
            out = out.transpose(0, 1)

        return identity + self.dropout_layer(self.proj_drop(out))


class ModuleList(BaseModule, nn.ModuleList):
    """ModuleList in openmmlab.

    Ensures that all modules in ``ModuleList`` have a different
    initialization strategy than the outer model.

    Args:
        modules (iterable, optional): An iterable of modules to add.
        init_cfg (dict, optional): Initialization config dict.
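
    Example:
        A minimal, illustrative sketch; the layer sizes and init config
        are arbitrary:

        >>> layers = ModuleList(
        ...     [nn.Linear(4, 4) for _ in range(2)],
        ...     init_cfg=dict(type='Constant', layer='Linear', val=1.))
        >>> layers.init_weights()  # fills every linear weight with 1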
    """

    def __init__(self,
                 modules: Optional[Iterable] = None,
                 init_cfg: Optional[dict] = None):
        BaseModule.__init__(self, init_cfg)
        nn.ModuleList.__init__(self, modules)


class Sequential(BaseModule, nn.Sequential):
    """Sequential module in openmmlab.

    Ensures that all modules in ``Sequential`` have a different
    initialization strategy than the outer model.

    Args:
        init_cfg (dict, optional): Initialization config dict.
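
    Example:
        A minimal, illustrative sketch; the layer sizes and init config
        are arbitrary:

        >>> model = Sequential(
        ...     nn.Linear(4, 8),
        ...     nn.Linear(8, 2),
        ...     init_cfg=dict(type='Xavier', layer='Linear'))
        >>> model.init_weights()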
    """

    def __init__(self, *args, init_cfg: Optional[dict] = None):
        BaseModule.__init__(self, init_cfg)
        nn.Sequential.__init__(self, *args)