Tim Mayer committed
Commit e98bd8c · 0 Parent(s):

Projectstructure + gitignore prepared

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitignore +12 -0
  2. build/lib/segformer_plusplus/__init__.py +4 -0
  3. build/lib/segformer_plusplus/build_model.py +108 -0
  4. build/lib/segformer_plusplus/configs/__init__.py +1 -0
  5. build/lib/segformer_plusplus/configs/segformer_mit_b0.py +28 -0
  6. build/lib/segformer_plusplus/configs/segformer_mit_b1.py +8 -0
  7. build/lib/segformer_plusplus/configs/segformer_mit_b2.py +6 -0
  8. build/lib/segformer_plusplus/configs/segformer_mit_b3.py +6 -0
  9. build/lib/segformer_plusplus/configs/segformer_mit_b4.py +6 -0
  10. build/lib/segformer_plusplus/configs/segformer_mit_b5.py +6 -0
  11. build/lib/segformer_plusplus/model/__init__.py +1 -0
  12. build/lib/segformer_plusplus/model/backbone/__init__.py +3 -0
  13. build/lib/segformer_plusplus/model/backbone/mit.py +479 -0
  14. build/lib/segformer_plusplus/model/head/__init__.py +3 -0
  15. build/lib/segformer_plusplus/model/head/segformer_head.py +95 -0
  16. build/lib/segformer_plusplus/random_benchmark.py +61 -0
  17. build/lib/segformer_plusplus/utils/__init__.py +12 -0
  18. build/lib/segformer_plusplus/utils/benchmark.py +76 -0
  19. build/lib/segformer_plusplus/utils/embed.py +330 -0
  20. build/lib/segformer_plusplus/utils/imagenet_weights.py +8 -0
  21. build/lib/segformer_plusplus/utils/registry.py +6 -0
  22. build/lib/segformer_plusplus/utils/shape_convert.py +107 -0
  23. build/lib/segformer_plusplus/utils/tome_presets.py +20 -0
  24. build/lib/segformer_plusplus/utils/wrappers.py +51 -0
  25. cityscapes_prediction_output_reference.txt +0 -0
  26. segformer_plusplus.egg-info/PKG-INFO +11 -0
  27. segformer_plusplus.egg-info/SOURCES.txt +29 -0
  28. segformer_plusplus.egg-info/dependency_links.txt +1 -0
  29. segformer_plusplus.egg-info/requires.txt +2 -0
  30. segformer_plusplus.egg-info/top_level.txt +1 -0
  31. segformer_plusplus/Registry/default_scope.py +95 -0
  32. segformer_plusplus/Registry/registry.py +735 -0
  33. segformer_plusplus/__init__.py +4 -0
  34. segformer_plusplus/build_model.py +107 -0
  35. segformer_plusplus/cityscape_benchmark.py +117 -0
  36. segformer_plusplus/configs/__init__.py +1 -0
  37. segformer_plusplus/configs/config/config.py +1545 -0
  38. segformer_plusplus/configs/config/lazy.py +267 -0
  39. segformer_plusplus/configs/config/utils.py +647 -0
  40. segformer_plusplus/configs/segformer_mit_b0.py +28 -0
  41. segformer_plusplus/configs/segformer_mit_b1.py +8 -0
  42. segformer_plusplus/configs/segformer_mit_b2.py +6 -0
  43. segformer_plusplus/configs/segformer_mit_b3.py +6 -0
  44. segformer_plusplus/configs/segformer_mit_b4.py +6 -0
  45. segformer_plusplus/configs/segformer_mit_b5.py +6 -0
  46. segformer_plusplus/model/__init__.py +1 -0
  47. segformer_plusplus/model/backbone/__init__.py +3 -0
  48. segformer_plusplus/model/backbone/mit.py +477 -0
  49. segformer_plusplus/model/base_module.py +390 -0
  50. segformer_plusplus/model/head/__init__.py +3 -0
.gitignore ADDED
@@ -0,0 +1,12 @@
+ __pycache__/
+ *.pyc
+ *.pyo
+ *.pyd
+ *.pth
+ *.pt
+ *.log
+ *.tmp
+ .env
+ .vscode/
+ .idea/
+ .DS_Store
build/lib/segformer_plusplus/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .build_model import create_model, create_custom_model
+ from .random_benchmark import random_benchmark
+
+ __all__ = ['create_model', 'create_custom_model', 'random_benchmark']
build/lib/segformer_plusplus/build_model.py ADDED
@@ -0,0 +1,108 @@
+ import os
+
+ from mmengine import registry
+ from mmengine.config import Config
+ from mmengine.model import BaseModule
+
+ from .utils import MODELS, imagenet_weights
+ from .utils import tome_presets
+
+
+ class SegFormer(BaseModule):
+     """
+     This class represents a SegFormer model that allows for the application of token merging.
+
+     Attributes:
+         backbone (BaseModule): MixVisionTransformer backbone
+         decode_head (BaseModule): SegFormer head
+
+     """
+     def __init__(self, cfg):
+         """
+         Initialize the SegFormer model.
+
+         Args:
+             cfg (Config): an mmengine Config object, which defines the backbone, head and token merging strategy used.
+
+         """
+         super().__init__()
+         self.backbone = registry.build_model_from_cfg(cfg.backbone, registry=MODELS)
+         self.decode_head = registry.build_model_from_cfg(cfg.decode_head, registry=MODELS)
+
+     def forward(self, x):
+         """
+         Forward pass of the model.
+
+         Args:
+             x (torch.Tensor): input tensor of shape [B, C, H, W]
+
+         Returns:
+             torch.Tensor: output tensor
+
+         """
+         x = self.backbone(x)
+         x = self.decode_head(x)
+         return x
+
+
+ def create_model(
+         backbone: str = 'b0',
+         tome_strategy: str = None,
+         out_channels: int = 19,
+         pretrained: bool = False,
+ ):
+     """
+     Create a SegFormer model using the predefined SegFormer backbones from the MiT series (b0-b5).
+
+     Args:
+         backbone (str): backbone name (e.g. 'b0')
+         tome_strategy (str | list(dict)): select a strategy from the presets ('bsm_hq', 'bsm_fast', 'n2d_2x2') or
+             define a custom strategy as a list of dictionaries, one per stage, each defining the strategy for
+             that stage
+         out_channels (int): number of output channels (e.g. 19 for the cityscapes semantic segmentation task)
+         pretrained: use pretrained (imagenet) weights
+
+     Returns:
+         BaseModule: SegFormer model
+
+     """
+     backbone = backbone.lower()
+     assert backbone in [f'b{i}' for i in range(6)]
+
+     wd = os.path.dirname(os.path.abspath(__file__))
+
+     cfg = Config.fromfile(os.path.join(wd, 'configs', f'segformer_mit_{backbone}.py'))
+
+     cfg.decode_head.out_channels = out_channels
+
+     if tome_strategy is not None:
+         if tome_strategy not in list(tome_presets.keys()):
+             print("Using custom merging strategy.")
+         cfg.backbone.tome_cfg = tome_presets[tome_strategy]
+
+     # load imagenet weights
+     if pretrained:
+         cfg.backbone.init_cfg = dict(type='Pretrained', checkpoint=imagenet_weights[backbone])
+
+     return SegFormer(cfg)
+
+
+ def create_custom_model(
+         model_cfg: Config,
+         tome_strategy: list[dict] = None,
+ ):
+     """
+     Create a SegFormer model with customizable backbone and head.
+
+     Args:
+         model_cfg (Config): model config, which defines the backbone and head
+         tome_strategy (list(dict)): custom token merging strategy
+
+     Returns:
+         BaseModule: SegFormer model
+
+     """
+     if tome_strategy is not None:
+         model_cfg.backbone.tome_cfg = tome_strategy
+
+     return SegFormer(model_cfg)
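For orientation (not part of the commit): a minimal usage sketch of the two factory functions above, inferred from their docstrings; the 'bsm_fast' preset name comes from utils/tome_presets.py later in this diff, and the input size is illustrative.

import torch
from segformer_plusplus import create_model

# Hedged example: backbone, preset name and tensor size are illustrative.
model = create_model(backbone='b2', tome_strategy='bsm_fast', out_channels=19, pretrained=False)
model.eval()
with torch.no_grad():
    logits = model(torch.rand(1, 3, 1024, 1024))  # segmentation logits with 19 output channels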
build/lib/segformer_plusplus/configs/__init__.py ADDED
@@ -0,0 +1 @@
+ __all__ = []
build/lib/segformer_plusplus/configs/segformer_mit_b0.py ADDED
@@ -0,0 +1,28 @@
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
+ backbone = dict(
+     type='MixVisionTransformer',
+     in_channels=3,
+     embed_dims=32,
+     num_stages=4,
+     num_layers=[2, 2, 2, 2],
+     num_heads=[1, 2, 5, 8],
+     patch_sizes=[7, 3, 3, 3],
+     sr_ratios=[8, 4, 2, 1],
+     out_indices=(0, 1, 2, 3),
+     mlp_ratio=4,
+     qkv_bias=True,
+     drop_rate=0.0,
+     attn_drop_rate=0.0,
+     drop_path_rate=0.1
+ )
+ decode_head = dict(
+     type='SegformerHead',
+     in_channels=[32, 64, 160, 256],
+     in_index=[0, 1, 2, 3],
+     channels=256,
+     dropout_ratio=0.1,
+     out_channels=19,
+     norm_cfg=norm_cfg,
+     align_corners=False,
+     interpolate_mode='bilinear'
+ )
build/lib/segformer_plusplus/configs/segformer_mit_b1.py ADDED
@@ -0,0 +1,8 @@
+ _base_ = ['./segformer_mit_b0.py']
+
+ backbone = dict(
+     embed_dims=64,
+ )
+ decode_head = dict(
+     in_channels=[64, 128, 320, 512]
+ )
build/lib/segformer_plusplus/configs/segformer_mit_b2.py ADDED
@@ -0,0 +1,6 @@
+ _base_ = ['./segformer_mit_b1.py']
+
+ backbone = dict(
+     embed_dims=64,
+     num_layers=[3, 4, 6, 3]
+ )
build/lib/segformer_plusplus/configs/segformer_mit_b3.py ADDED
@@ -0,0 +1,6 @@
+ _base_ = ['./segformer_mit_b1.py']
+
+ backbone = dict(
+     embed_dims=64,
+     num_layers=[3, 4, 18, 3]
+ )
build/lib/segformer_plusplus/configs/segformer_mit_b4.py ADDED
@@ -0,0 +1,6 @@
+ _base_ = ['./segformer_mit_b1.py']
+
+ backbone = dict(
+     embed_dims=64,
+     num_layers=[3, 8, 27, 3]
+ )
build/lib/segformer_plusplus/configs/segformer_mit_b5.py ADDED
@@ -0,0 +1,6 @@
+ _base_ = ['./segformer_mit_b1.py']
+
+ backbone = dict(
+     embed_dims=64,
+     num_layers=[3, 6, 40, 3]
+ )
build/lib/segformer_plusplus/model/__init__.py ADDED
@@ -0,0 +1 @@
+ __all__ = []
build/lib/segformer_plusplus/model/backbone/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .mit import MixVisionTransformer
+
+ __all__ = ['MixVisionTransformer']
build/lib/segformer_plusplus/model/backbone/mit.py ADDED
@@ -0,0 +1,479 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import math
+
+ import torch
+ import torch.nn as nn
+ import torch.utils.checkpoint as cp
+ from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
+ from mmcv.cnn.bricks.drop import build_dropout
+ from mmcv.cnn.bricks.transformer import MultiheadAttention
+ from mmengine.model import BaseModule, ModuleList, Sequential
+ from mmengine.model.weight_init import (constant_init, normal_init,
+                                         trunc_normal_init)
+ from tomesd.merge import bipartite_soft_matching_random2d
+
+ from ...utils import PatchEmbed
+ from ...utils import nchw_to_nlc, nlc_to_nchw
+ from ...utils import MODELS
+
+
+ class MixFFN(BaseModule):
+     """An implementation of MixFFN of Segformer.
+
+     The differences between MixFFN & FFN:
+         1. Use 1X1 Conv to replace Linear layer.
+         2. Introduce 3X3 Conv to encode positional information.
+     Args:
+         embed_dims (int): The feature dimension. Same as
+             `MultiheadAttention`. Defaults: 256.
+         feedforward_channels (int): The hidden dimension of FFNs.
+             Defaults: 1024.
+         act_cfg (dict, optional): The activation config for FFNs.
+             Default: dict(type='ReLU')
+         ffn_drop (float, optional): Probability of an element to be
+             zeroed in FFN. Default 0.0.
+         dropout_layer (obj:`ConfigDict`): The dropout_layer used
+             when adding the shortcut.
+         init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+             Default: None.
+     """
+
+     def __init__(self,
+                  embed_dims,
+                  feedforward_channels,
+                  act_cfg=dict(type='GELU'),
+                  ffn_drop=0.,
+                  dropout_layer=None,
+                  init_cfg=None):
+         super().__init__(init_cfg)
+
+         self.embed_dims = embed_dims
+         self.feedforward_channels = feedforward_channels
+         self.act_cfg = act_cfg
+         self.activate = build_activation_layer(act_cfg)
+
+         in_channels = embed_dims
+         fc1 = Conv2d(
+             in_channels=in_channels,
+             out_channels=feedforward_channels,
+             kernel_size=1,
+             stride=1,
+             bias=True)
+         # 3x3 depth wise conv to provide positional encode information
+         pe_conv = Conv2d(
+             in_channels=feedforward_channels,
+             out_channels=feedforward_channels,
+             kernel_size=3,
+             stride=1,
+             padding=(3 - 1) // 2,
+             bias=True,
+             groups=feedforward_channels)
+         fc2 = Conv2d(
+             in_channels=feedforward_channels,
+             out_channels=in_channels,
+             kernel_size=1,
+             stride=1,
+             bias=True)
+         drop = nn.Dropout(ffn_drop)
+         layers = [fc1, pe_conv, self.activate, drop, fc2, drop]
+         self.layers = Sequential(*layers)
+         self.dropout_layer = build_dropout(
+             dropout_layer) if dropout_layer else torch.nn.Identity()
+
+     def forward(self, x, hw_shape, identity=None):
+         out = nlc_to_nchw(x, hw_shape)
+         out = self.layers(out)
+         out = nchw_to_nlc(out)
+         if identity is None:
+             identity = x
+         return identity + self.dropout_layer(out)
+
+
+ class EfficientMultiheadAttention(MultiheadAttention):
+     """An implementation of Efficient Multi-head Attention of Segformer.
+
+     This module is modified from MultiheadAttention which is a module from
+     mmcv.cnn.bricks.transformer.
+     Args:
+         embed_dims (int): The embedding dimension.
+         num_heads (int): Parallel attention heads.
+         attn_drop (float): A Dropout layer on attn_output_weights.
+             Default: 0.0.
+         proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
+             Default: 0.0.
+         dropout_layer (obj:`ConfigDict`): The dropout_layer used
+             when adding the shortcut. Default: None.
+         init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+             Default: None.
+         batch_first (bool): Key, Query and Value are shape of
+             (batch, n, embed_dim) or (n, batch, embed_dim). Default: False.
+         qkv_bias (bool): enable bias for qkv if True. Default True.
+         norm_cfg (dict): Config dict for normalization layer.
+             Default: dict(type='LN').
+         sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
+             Attention of Segformer. Default: 1.
+     """
+
+     def __init__(self,
+                  embed_dims,
+                  num_heads,
+                  attn_drop=0.,
+                  proj_drop=0.,
+                  dropout_layer=None,
+                  init_cfg=None,
+                  batch_first=True,
+                  qkv_bias=False,
+                  tome_cfg=dict(),
+                  norm_cfg=dict(type='LN'),
+                  sr_ratio=1):
+         super().__init__(
+             embed_dims,
+             num_heads,
+             attn_drop,
+             proj_drop,
+             dropout_layer=dropout_layer,
+             init_cfg=init_cfg,
+             batch_first=batch_first,
+             bias=qkv_bias)
+
+         self.q_mode = tome_cfg.get('q_mode')
+         self.kv_mode = tome_cfg.get('kv_mode')
+         self.tome_cfg = tome_cfg
+
+         self.sr_ratio = sr_ratio
+         if sr_ratio > 1:
+             self.sr = Conv2d(
+                 in_channels=embed_dims,
+                 out_channels=embed_dims,
+                 kernel_size=sr_ratio,
+                 stride=sr_ratio)
+             # The ret[0] of build_norm_layer is norm name.
+             self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
+
+     def forward(self, x, hw_shape, identity=None):
+         x_q = x
+
+         if self.sr_ratio > 1:
+             x_kv = nlc_to_nchw(x, hw_shape)
+             x_kv = self.sr(x_kv)
+             x_kv = nchw_to_nlc(x_kv)
+             x_kv = self.norm(x_kv)
+         else:
+             x_kv = x
+
+         # 2D Neighbour Merging KV
+         if self.kv_mode == 'n2d':
+             kv_hw_shape = (int(hw_shape[0] / self.sr_ratio), int(hw_shape[1] / self.sr_ratio))
+             x_kv = nlc_to_nchw(x_kv, kv_hw_shape)
+             x_kv = torch.nn.functional.avg_pool2d(x_kv, kernel_size=self.tome_cfg['kv_s'],
+                                                   stride=self.tome_cfg['kv_s'],
+                                                   ceil_mode=True)
+             x_kv = nchw_to_nlc(x_kv)
+
+         # Bipartite Soft Matching (tomesd) KV
+         if self.kv_mode == 'bsm':
+             w_kv = int(hw_shape[1] / self.sr_ratio)
+             h_kv = int(hw_shape[0] / self.sr_ratio)
+             merge, unmerge = bipartite_soft_matching_random2d(metric=x_kv, w=w_kv, h=h_kv,
+                                                               r=int(x_kv.size()[1] * self.tome_cfg['kv_r']),
+                                                               sx=self.tome_cfg['kv_sx'], sy=self.tome_cfg['kv_sy'],
+                                                               no_rand=True)
+             x_kv = merge(x_kv)
+
+         if identity is None:
+             identity = x_q
+
+         # 1D Neighbor Merging Q
+         if self.q_mode == 'n1d':
+             x_q = x_q.transpose(-2, -1)
+             x_q = torch.nn.functional.avg_pool1d(x_q, kernel_size=self.tome_cfg['q_s'],
+                                                  stride=self.tome_cfg['q_s'],
+                                                  ceil_mode=True)
+             x_q = x_q.transpose(-2, -1)
+
+         # 2D Neighbor Merging Q
+         if self.q_mode == 'n2d':
+             reduced_hw = (int(torch.ceil(torch.tensor(hw_shape[0] / self.tome_cfg['q_s'][0]))),
+                           int(torch.ceil(torch.tensor(hw_shape[1] / self.tome_cfg['q_s'][1]))))
+             x_q = nlc_to_nchw(x_q, hw_shape)
+             x_q = torch.nn.functional.avg_pool2d(x_q, kernel_size=self.tome_cfg['q_s'],
+                                                  stride=self.tome_cfg['q_s'],
+                                                  ceil_mode=True)
+             x_q = nchw_to_nlc(x_q)
+
+         # Bipartite Soft Matching (tomesd) Q
+         if self.q_mode == 'bsm':
+             merge, unmerge = bipartite_soft_matching_random2d(metric=x_q, w=hw_shape[1], h=hw_shape[0],
+                                                               r=int(x_q.size()[1] * self.tome_cfg['q_r']),
+                                                               sx=self.tome_cfg['q_sx'], sy=self.tome_cfg['q_sy'],
+                                                               no_rand=True)
+             x_q = merge(x_q)
+
+         # Because the dataflow('key', 'query', 'value') of
+         # ``torch.nn.MultiheadAttention`` is (num_query, batch,
+         # embed_dims), we should adjust the shape of dataflow from
+         # batch_first (batch, num_query, embed_dims) to num_query_first
+         # (num_query, batch, embed_dims), and recover ``attn_output``
+         # from num_query_first to batch_first.
+         if self.batch_first:
+             x_q = x_q.transpose(0, 1)
+             x_kv = x_kv.transpose(0, 1)
+         out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]
+         if self.batch_first:
+             out = out.transpose(0, 1)
+
+         # Unmerging BSM (tome+tomesd)
+         if self.q_mode == 'bsm':
+             out = unmerge(out)
+
+         # Unmerging 1D Neighbour Merging
+         if self.q_mode == 'n1d':
+             out = out.transpose(-2, -1)
+             out = torch.nn.functional.interpolate(out, size=identity.size()[-2])
+             out = out.transpose(-2, -1)
+
+         # Unmerging 2D Neighbor Merging
+         if self.q_mode == 'n2d':
+             out = nlc_to_nchw(out, reduced_hw)
+             out = torch.nn.functional.interpolate(out, size=hw_shape)
+             out = nchw_to_nlc(out)
+
+         return identity + self.dropout_layer(self.proj_drop(out))
+
+
+ class TransformerEncoderLayer(BaseModule):
+     """Implements one encoder layer in Segformer.
+
+     Args:
+         embed_dims (int): The feature dimension.
+         num_heads (int): Parallel attention heads.
+         feedforward_channels (int): The hidden dimension for FFNs.
+         drop_rate (float): Probability of an element to be zeroed
+             after the feed forward layer. Default 0.0.
+         attn_drop_rate (float): The drop out rate for attention layer.
+             Default 0.0.
+         drop_path_rate (float): stochastic depth rate. Default 0.0.
+         qkv_bias (bool): enable bias for qkv if True.
+             Default: True.
+         act_cfg (dict): The activation config for FFNs.
+             Default: dict(type='GELU').
+         norm_cfg (dict): Config dict for normalization layer.
+             Default: dict(type='LN').
+         batch_first (bool): Key, Query and Value are shape of
+             (batch, n, embed_dim) or (n, batch, embed_dim). Default: False.
+         init_cfg (dict, optional): Initialization config dict.
+             Default: None.
+         sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
+             Attention of Segformer. Default: 1.
+         with_cp (bool): Use checkpoint or not. Using checkpoint will save
+             some memory while slowing down the training speed. Default: False.
+     """
+
+     def __init__(self,
+                  embed_dims,
+                  num_heads,
+                  feedforward_channels,
+                  drop_rate=0.,
+                  attn_drop_rate=0.,
+                  drop_path_rate=0.,
+                  qkv_bias=True,
+                  tome_cfg=dict(),
+                  act_cfg=dict(type='GELU'),
+                  norm_cfg=dict(type='LN'),
+                  batch_first=True,
+                  sr_ratio=1,
+                  with_cp=False):
+         super().__init__()
+
+         # The ret[0] of build_norm_layer is norm name.
+         self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
+
+         self.attn = EfficientMultiheadAttention(
+             embed_dims=embed_dims,
+             num_heads=num_heads,
+             attn_drop=attn_drop_rate,
+             proj_drop=drop_rate,
+             dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+             batch_first=batch_first,
+             qkv_bias=qkv_bias,
+             tome_cfg=tome_cfg,
+             norm_cfg=norm_cfg,
+             sr_ratio=sr_ratio)
+
+         # The ret[0] of build_norm_layer is norm name.
+         self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
+
+         self.ffn = MixFFN(
+             embed_dims=embed_dims,
+             feedforward_channels=feedforward_channels,
+             ffn_drop=drop_rate,
+             dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+             act_cfg=act_cfg)
+
+         self.with_cp = with_cp
+
+     def forward(self, x, hw_shape):
+
+         def _inner_forward(x):
+             x = self.attn(self.norm1(x), hw_shape, identity=x)
+             x = self.ffn(self.norm2(x), hw_shape, identity=x)
+             return x
+
+         if self.with_cp and x.requires_grad:
+             x = cp.checkpoint(_inner_forward, x)
+         else:
+             x = _inner_forward(x)
+         return x
+
+
+ @MODELS.register_module()
+ class MixVisionTransformer(BaseModule):
+     """The backbone of Segformer.
+
+     This backbone is the implementation of `SegFormer: Simple and
+     Efficient Design for Semantic Segmentation with
+     Transformers <https://arxiv.org/abs/2105.15203>`_.
+     Args:
+         in_channels (int): Number of input channels. Default: 3.
+         embed_dims (int): Embedding dimension. Default: 768.
+         num_stages (int): The num of stages. Default: 4.
+         num_layers (Sequence[int]): The layer number of each transformer encode
+             layer. Default: [3, 4, 6, 3].
+         num_heads (Sequence[int]): The attention heads of each transformer
+             encode layer. Default: [1, 2, 4, 8].
+         patch_sizes (Sequence[int]): The patch_size of each overlapped patch
+             embedding. Default: [7, 3, 3, 3].
+         strides (Sequence[int]): The stride of each overlapped patch embedding.
+             Default: [4, 2, 2, 2].
+         sr_ratios (Sequence[int]): The spatial reduction rate of each
+             transformer encode layer. Default: [8, 4, 2, 1].
+         out_indices (Sequence[int] | int): Output from which stages.
+             Default: (0, 1, 2, 3).
+         mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
+             Default: 4.
+         qkv_bias (bool): Enable bias for qkv if True. Default: True.
+         drop_rate (float): Probability of an element to be zeroed.
+             Default 0.0
+         attn_drop_rate (float): The drop out rate for attention layer.
+             Default 0.0
+         drop_path_rate (float): stochastic depth rate. Default 0.0
+         norm_cfg (dict): Config dict for normalization layer.
+             Default: dict(type='LN')
+         act_cfg (dict): The activation config for FFNs.
+             Default: dict(type='GELU').
+         pretrained (str, optional): model pretrained path. Default: None.
+         init_cfg (dict or list[dict], optional): Initialization config dict.
+             Default: None.
+         with_cp (bool): Use checkpoint or not. Using checkpoint will save
+             some memory while slowing down the training speed. Default: False.
+     """
+
+     def __init__(self,
+                  in_channels=3,
+                  embed_dims=64,
+                  num_stages=4,
+                  num_layers=[3, 4, 6, 3],
+                  num_heads=[1, 2, 4, 8],
+                  patch_sizes=[7, 3, 3, 3],
+                  strides=[4, 2, 2, 2],
+                  sr_ratios=[8, 4, 2, 1],
+                  out_indices=(0, 1, 2, 3),
+                  mlp_ratio=4,
+                  qkv_bias=True,
+                  drop_rate=0.,
+                  attn_drop_rate=0.,
+                  drop_path_rate=0.,
+                  tome_cfg=[dict(), dict(), dict(), dict()],
+                  act_cfg=dict(type='GELU'),
+                  norm_cfg=dict(type='LN', eps=1e-6),
+                  init_cfg=None,
+                  with_cp=False,
+                  down_sample=False):
+         super().__init__(init_cfg=init_cfg)
+
+         self.embed_dims = embed_dims
+         self.num_stages = num_stages
+         self.num_layers = num_layers
+         self.num_heads = num_heads
+         self.patch_sizes = patch_sizes
+         self.strides = strides
+         self.sr_ratios = sr_ratios
+         self.with_cp = with_cp
+         self.down_sample = down_sample
+         assert num_stages == len(num_layers) == len(num_heads) \
+                == len(patch_sizes) == len(strides) == len(sr_ratios)
+
+         self.out_indices = out_indices
+         assert max(out_indices) < self.num_stages
+
+         # transformer encoder
+         dpr = [
+             x.item()
+             for x in torch.linspace(0, drop_path_rate, sum(num_layers))
+         ]  # stochastic num_layer decay rule
+
+         cur = 0
+         self.layers = ModuleList()
+         for i, num_layer in enumerate(num_layers):
+             embed_dims_i = embed_dims * num_heads[i]
+             patch_embed = PatchEmbed(
+                 in_channels=in_channels,
+                 embed_dims=embed_dims_i,
+                 kernel_size=patch_sizes[i],
+                 stride=strides[i],
+                 padding=patch_sizes[i] // 2,
+                 norm_cfg=norm_cfg)
+             layer = ModuleList([
+                 TransformerEncoderLayer(
+                     embed_dims=embed_dims_i,
+                     num_heads=num_heads[i],
+                     feedforward_channels=mlp_ratio * embed_dims_i,
+                     drop_rate=drop_rate,
+                     attn_drop_rate=attn_drop_rate,
+                     drop_path_rate=dpr[cur + idx],
+                     qkv_bias=qkv_bias,
+                     tome_cfg=tome_cfg[i],
+                     act_cfg=act_cfg,
+                     norm_cfg=norm_cfg,
+                     with_cp=with_cp,
+                     sr_ratio=sr_ratios[i]) for idx in range(num_layer)
+             ])
+             in_channels = embed_dims_i
+             # The ret[0] of build_norm_layer is norm name.
+             norm = build_norm_layer(norm_cfg, embed_dims_i)[1]
+             self.layers.append(ModuleList([patch_embed, layer, norm]))
+             cur += num_layer
+
+     def init_weights(self):
+         if self.init_cfg is None:
+             for m in self.modules():
+                 if isinstance(m, nn.Linear):
+                     trunc_normal_init(m, std=.02, bias=0.)
+                 elif isinstance(m, nn.LayerNorm):
+                     constant_init(m, val=1.0, bias=0.)
+                 elif isinstance(m, nn.Conv2d):
+                     fan_out = m.kernel_size[0] * m.kernel_size[
+                         1] * m.out_channels
+                     fan_out //= m.groups
+                     normal_init(
+                         m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
+         else:
+             super().init_weights()
+
+     def forward(self, x):
+         if self.down_sample:
+             x = torch.nn.functional.interpolate(x, scale_factor=(0.5, 0.5))
+         outs = []
+
+         for i, layer in enumerate(self.layers):
+             x, hw_shape = layer[0](x)
+             for block in layer[1]:
+                 x = block(x, hw_shape)
+             x = layer[2](x)
+             x = nlc_to_nchw(x, hw_shape)
+             if i in self.out_indices:
+                 outs.append(x)
+
+         return outs
build/lib/segformer_plusplus/model/head/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .segformer_head import SegformerHead
+
+ __all__ = ['SegformerHead']
build/lib/segformer_plusplus/model/head/segformer_head.py ADDED
@@ -0,0 +1,95 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import torch
+ import torch.nn as nn
+ from mmcv.cnn import ConvModule
+ from mmengine.model import BaseModule
+
+ from ...utils import MODELS
+ from ...utils import resize
+
+
+ @MODELS.register_module()
+ class SegformerHead(BaseModule):
+     """The all mlp Head of segformer.
+
+     This head is the implementation of
+     `Segformer <https://arxiv.org/abs/2105.15203>`_.
+
+     Args:
+         interpolate_mode: The interpolate mode of MLP head upsample operation.
+             Default: 'bilinear'.
+     """
+
+     def __init__(self,
+                  in_channels=[32, 64, 160, 256],
+                  in_index=[0, 1, 2, 3],
+                  channels=256,
+                  dropout_ratio=0.1,
+                  out_channels=19,
+                  norm_cfg=None,
+                  align_corners=False,
+                  interpolate_mode='bilinear'):
+         super().__init__()
+
+         self.in_channels = in_channels
+         self.in_index = in_index
+         self.channels = channels
+         self.dropout_ratio = dropout_ratio
+         self.out_channels = out_channels
+         self.norm_cfg = norm_cfg
+         self.align_corners = align_corners
+         self.interpolate_mode = interpolate_mode
+
+         self.act_cfg = dict(type='ReLU')
+         self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
+         if dropout_ratio > 0:
+             self.dropout = nn.Dropout2d(dropout_ratio)
+         else:
+             self.dropout = None
+
+         num_inputs = len(self.in_channels)
+
+         assert num_inputs == len(self.in_index)
+
+         self.convs = nn.ModuleList()
+         for i in range(num_inputs):
+             self.convs.append(
+                 ConvModule(
+                     in_channels=self.in_channels[i],
+                     out_channels=self.channels,
+                     kernel_size=1,
+                     stride=1,
+                     norm_cfg=self.norm_cfg,
+                     act_cfg=self.act_cfg))
+
+         self.fusion_conv = ConvModule(
+             in_channels=self.channels * num_inputs,
+             out_channels=self.channels,
+             kernel_size=1,
+             norm_cfg=self.norm_cfg)
+
+     def cls_seg(self, feat):
+         """Classify each pixel."""
+         if self.dropout is not None:
+             feat = self.dropout(feat)
+         output = self.conv_seg(feat)
+         return output
+
+     def forward(self, inputs):
+         # Receive 4 stage backbone feature map: 1/4, 1/8, 1/16, 1/32
+         outs = []
+         for idx in range(len(inputs)):
+             x = inputs[idx]
+             conv = self.convs[idx]
+             outs.append(
+                 resize(
+                     input=conv(x),
+                     size=inputs[0].shape[2:],
+                     mode=self.interpolate_mode,
+                     align_corners=self.align_corners))
+
+         out = self.fusion_conv(torch.cat(outs, dim=1))
+
+         out = self.cls_seg(out)
+
+         return out
build/lib/segformer_plusplus/random_benchmark.py ADDED
@@ -0,0 +1,61 @@
+ from typing import Union, List, Tuple
+
+ import numpy as np
+ import torch
+
+ from .utils import benchmark
+
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+
+
+ def random_benchmark(
+         model: torch.nn.Module,
+         batch_size: Union[int, List[int]] = 1,
+         image_size: Union[Tuple[int], List[Tuple[int]]] = (3, 1024, 1024),
+ ):
+     """
+     Calculate the FPS of a given model using randomly generated tensors.
+
+     Args:
+         model: instance of a model (e.g. SegFormer)
+         batch_size: the batch size(s) at which to calculate the FPS (e.g. 1 or [1, 2, 4])
+         image_size: the size of the images to use (e.g. (3, 1024, 1024))
+
+     Returns: the FPS values calculated for all image sizes and batch sizes in the form of a dictionary
+
+     """
+     if isinstance(batch_size, int):
+         batch_size = [batch_size]
+     if isinstance(image_size, tuple):
+         image_size = [image_size]
+
+     values = {}
+     throughput_values = []
+
+     for i in image_size:
+         # fill with fps for each batch size
+         fps = []
+         for b in batch_size:
+             for _ in range(4):
+                 # Baseline benchmark
+                 if i[1] >= 1024:
+                     r = 16
+                 else:
+                     r = 32
+                 baseline_throughput = benchmark(
+                     model.to(device),
+                     device=device,
+                     verbose=True,
+                     runs=r,
+                     batch_size=b,
+                     input_size=i
+                 )
+                 throughput_values.append(baseline_throughput)
+             throughput_values = np.asarray(throughput_values)
+             throughput = np.around(np.mean(throughput_values), decimals=2)
+             print('Im_size:', i, 'Batch_size:', b, 'Mean:', throughput, 'Std:',
+                   np.around(np.std(throughput_values), decimals=2))
+             throughput_values = []
+             fps.append({b: throughput})
+         values[i] = fps
+     return values
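A hedged sketch of calling random_benchmark, based only on the signature and docstring above; the batch sizes and image size are illustrative.

from segformer_plusplus import create_model, random_benchmark

model = create_model('b0')
# Measures FPS for each (image size, batch size) combination on random inputs.
fps = random_benchmark(model, batch_size=[1, 2], image_size=(3, 1024, 1024))
print(fps)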
build/lib/segformer_plusplus/utils/__init__.py ADDED
@@ -0,0 +1,12 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ from .embed import PatchEmbed
+ from .shape_convert import nchw_to_nlc, nlc_to_nchw
+ from .wrappers import resize
+ from .tome_presets import tome_presets
+ from .registry import MODELS
+ from .imagenet_weights import imagenet_weights
+ from .benchmark import benchmark
+
+ __all__ = [
+     'PatchEmbed', 'nchw_to_nlc', 'nlc_to_nchw', 'resize', 'tome_presets', 'MODELS', 'imagenet_weights', 'benchmark'
+ ]
build/lib/segformer_plusplus/utils/benchmark.py ADDED
@@ -0,0 +1,76 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # Source: https://github.com/facebookresearch/ToMe/blob/main/tome/utils.py
+ # --------------------------------------------------------
+
+ import time
+ from typing import Tuple
+
+ import torch
+ from tqdm import tqdm
+
+
+ def benchmark(
+         model: torch.nn.Module,
+         device: torch.device = 0,
+         input_size: Tuple[int] = (3, 224, 224),
+         batch_size: int = 64,
+         runs: int = 40,
+         throw_out: float = 0.25,
+         use_fp16: bool = False,
+         verbose: bool = False,
+ ) -> float:
+     """
+     Benchmark the given model with random inputs at the given batch size.
+
+     Args:
+      - model: the module to benchmark
+      - device: the device to use for benchmarking
+      - input_size: the input size to pass to the model (channels, h, w)
+      - batch_size: the batch size to use for evaluation
+      - runs: the number of total runs to do
+      - throw_out: the percentage of runs to throw out at the start of testing
+      - use_fp16: whether or not to benchmark with float16 and autocast
+      - verbose: whether or not to use tqdm to print progress / print throughput at end
+
+     Returns:
+      - the throughput measured in images / second
+     """
+     if not isinstance(device, torch.device):
+         device = torch.device(device)
+     is_cuda = torch.device(device).type == "cuda"
+
+     model = model.eval().to(device)
+     input = torch.rand(batch_size, *input_size, device=device)
+     if use_fp16:
+         input = input.half()
+
+     warm_up = int(runs * throw_out)
+     total = 0
+     start = time.time()
+
+     with torch.autocast(device.type, enabled=use_fp16):
+         with torch.no_grad():
+             for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"):
+                 if i == warm_up:
+                     if is_cuda:
+                         torch.cuda.synchronize()
+                     total = 0
+                     start = time.time()
+
+                 model(input)
+                 total += batch_size
+
+     if is_cuda:
+         torch.cuda.synchronize()
+
+     end = time.time()
+     elapsed = end - start
+
+     throughput = total / elapsed
+
+     if verbose:
+         print(f"Throughput: {throughput:.2f} im/s")
+
+     return throughput
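The ToMe benchmark helper can also be called directly; a minimal sketch assuming a CUDA device is available (arguments mirror the signature above, values are illustrative).

import torch
from segformer_plusplus import create_model
from segformer_plusplus.utils import benchmark

model = create_model('b0')
# Throughput in images/second; the first 25% of runs are treated as warm-up.
ips = benchmark(model, device=torch.device('cuda:0'), input_size=(3, 512, 512),
                batch_size=2, runs=20, verbose=True)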
build/lib/segformer_plusplus/utils/embed.py ADDED
@@ -0,0 +1,330 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import math
+ from typing import Sequence
+
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from mmcv.cnn import build_conv_layer, build_norm_layer
+ from mmengine.model import BaseModule
+ from mmengine.utils import to_2tuple
+
+
+ class AdaptivePadding(nn.Module):
+     """Applies padding to input (if needed) so that input can get fully covered
+     by filter you specified. It supports two modes "same" and "corner". The
+     "same" mode is same with "SAME" padding mode in TensorFlow, pad zero around
+     input. The "corner" mode would pad zero to bottom right.
+
+     Args:
+         kernel_size (int | tuple): Size of the kernel.
+         stride (int | tuple): Stride of the filter. Default: 1.
+         dilation (int | tuple): Spacing between kernel elements.
+             Default: 1.
+         padding (str): Support "same" and "corner", "corner" mode
+             would pad zero to bottom right, and "same" mode would
+             pad zero around input. Default: "corner".
+     Example:
+         >>> kernel_size = 16
+         >>> stride = 16
+         >>> dilation = 1
+         >>> input = torch.rand(1, 1, 15, 17)
+         >>> adap_pad = AdaptivePadding(
+         >>>     kernel_size=kernel_size,
+         >>>     stride=stride,
+         >>>     dilation=dilation,
+         >>>     padding="corner")
+         >>> out = adap_pad(input)
+         >>> assert (out.shape[2], out.shape[3]) == (16, 32)
+         >>> input = torch.rand(1, 1, 16, 17)
+         >>> out = adap_pad(input)
+         >>> assert (out.shape[2], out.shape[3]) == (16, 32)
+     """
+
+     def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):
+
+         super().__init__()
+
+         assert padding in ('same', 'corner')
+
+         kernel_size = to_2tuple(kernel_size)
+         stride = to_2tuple(stride)
+         dilation = to_2tuple(dilation)
+
+         self.padding = padding
+         self.kernel_size = kernel_size
+         self.stride = stride
+         self.dilation = dilation
+
+     def get_pad_shape(self, input_shape):
+         input_h, input_w = input_shape
+         kernel_h, kernel_w = self.kernel_size
+         stride_h, stride_w = self.stride
+         output_h = math.ceil(input_h / stride_h)
+         output_w = math.ceil(input_w / stride_w)
+         pad_h = max((output_h - 1) * stride_h +
+                     (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0)
+         pad_w = max((output_w - 1) * stride_w +
+                     (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0)
+         return pad_h, pad_w
+
+     def forward(self, x):
+         pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
+         if pad_h > 0 or pad_w > 0:
+             if self.padding == 'corner':
+                 x = F.pad(x, [0, pad_w, 0, pad_h])
+             elif self.padding == 'same':
+                 x = F.pad(x, [
+                     pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
+                     pad_h - pad_h // 2
+                 ])
+         return x
+
+
+ class PatchEmbed(BaseModule):
+     """Image to Patch Embedding.
+
+     We use a conv layer to implement PatchEmbed.
+
+     Args:
+         in_channels (int): The num of input channels. Default: 3
+         embed_dims (int): The dimensions of embedding. Default: 768
+         conv_type (str): The config dict for embedding
+             conv layer type selection. Default: "Conv2d".
+         kernel_size (int): The kernel_size of embedding conv. Default: 16.
+         stride (int, optional): The slide stride of embedding conv.
+             Default: None (Would be set as `kernel_size`).
+         padding (int | tuple | string): The padding length of
+             embedding conv. When it is a string, it means the mode
+             of adaptive padding, support "same" and "corner" now.
+             Default: "corner".
+         dilation (int): The dilation rate of embedding conv. Default: 1.
+         bias (bool): Bias of embed conv. Default: True.
+         norm_cfg (dict, optional): Config dict for normalization layer.
+             Default: None.
+         input_size (int | tuple | None): The size of input, which will be
+             used to calculate the out size. Only work when `dynamic_size`
+             is False. Default: None.
+         init_cfg (`mmengine.ConfigDict`, optional): The Config for
+             initialization. Default: None.
+     """
+
+     def __init__(self,
+                  in_channels=3,
+                  embed_dims=768,
+                  conv_type='Conv2d',
+                  kernel_size=16,
+                  stride=None,
+                  padding='corner',
+                  dilation=1,
+                  bias=True,
+                  norm_cfg=None,
+                  input_size=None,
+                  init_cfg=None):
+         super().__init__(init_cfg=init_cfg)
+
+         self.embed_dims = embed_dims
+         if stride is None:
+             stride = kernel_size
+
+         kernel_size = to_2tuple(kernel_size)
+         stride = to_2tuple(stride)
+         dilation = to_2tuple(dilation)
+
+         if isinstance(padding, str):
+             self.adap_padding = AdaptivePadding(
+                 kernel_size=kernel_size,
+                 stride=stride,
+                 dilation=dilation,
+                 padding=padding)
+             # disable the padding of conv
+             padding = 0
+         else:
+             self.adap_padding = None
+         padding = to_2tuple(padding)
+
+         self.projection = build_conv_layer(
+             dict(type=conv_type),
+             in_channels=in_channels,
+             out_channels=embed_dims,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+             dilation=dilation,
+             bias=bias)
+
+         if norm_cfg is not None:
+             self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
+         else:
+             self.norm = None
+
+         if input_size:
+             input_size = to_2tuple(input_size)
+             # `init_out_size` would be used outside to
+             # calculate the num_patches
+             # when `use_abs_pos_embed` outside
+             self.init_input_size = input_size
+             if self.adap_padding:
+                 pad_h, pad_w = self.adap_padding.get_pad_shape(input_size)
+                 input_h, input_w = input_size
+                 input_h = input_h + pad_h
+                 input_w = input_w + pad_w
+                 input_size = (input_h, input_w)
+
+             # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
+             h_out = (input_size[0] + 2 * padding[0] - dilation[0] *
+                      (kernel_size[0] - 1) - 1) // stride[0] + 1
+             w_out = (input_size[1] + 2 * padding[1] - dilation[1] *
+                      (kernel_size[1] - 1) - 1) // stride[1] + 1
+             self.init_out_size = (h_out, w_out)
+         else:
+             self.init_input_size = None
+             self.init_out_size = None
+
+     def forward(self, x):
+         """
+         Args:
+             x (Tensor): Has shape (B, C, H, W). In most case, C is 3.
+
+         Returns:
+             tuple: Contains merged results and its spatial shape.
+
+             - x (Tensor): Has shape (B, out_h * out_w, embed_dims)
+             - out_size (tuple[int]): Spatial shape of x, arrange as
+                 (out_h, out_w).
+         """
+
+         if self.adap_padding:
+             x = self.adap_padding(x)
+
+         x = self.projection(x)
+         out_size = (x.shape[2], x.shape[3])
+         x = x.flatten(2).transpose(1, 2)
+         if self.norm is not None:
+             x = self.norm(x)
+         return x, out_size
+
+
+ class PatchMerging(BaseModule):
+     """Merge patch feature map.
+
+     This layer groups feature map by kernel_size, and applies norm and linear
+     layers to the grouped feature map. Our implementation uses `nn.Unfold` to
+     merge patch, which is about 25% faster than original implementation.
+     Instead, we need to modify pretrained models for compatibility.
+
+     Args:
+         in_channels (int): The num of input channels.
+         out_channels (int): The num of output channels.
+         kernel_size (int | tuple, optional): the kernel size in the unfold
+             layer. Defaults to 2.
+         stride (int | tuple, optional): the stride of the sliding blocks in the
+             unfold layer. Default: None. (Would be set as `kernel_size`)
+         padding (int | tuple | string): The padding length of
+             embedding conv. When it is a string, it means the mode
+             of adaptive padding, support "same" and "corner" now.
+             Default: "corner".
+         dilation (int | tuple, optional): dilation parameter in the unfold
+             layer. Default: 1.
+         bias (bool, optional): Whether to add bias in linear layer or not.
+             Defaults: False.
+         norm_cfg (dict, optional): Config dict for normalization layer.
+             Default: dict(type='LN').
+         init_cfg (dict, optional): The extra config for initialization.
+             Default: None.
+     """
+
+     def __init__(self,
+                  in_channels,
+                  out_channels,
+                  kernel_size=2,
+                  stride=None,
+                  padding='corner',
+                  dilation=1,
+                  bias=False,
+                  norm_cfg=dict(type='LN'),
+                  init_cfg=None):
+         super().__init__(init_cfg=init_cfg)
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         if stride:
+             stride = stride
+         else:
+             stride = kernel_size
+
+         kernel_size = to_2tuple(kernel_size)
+         stride = to_2tuple(stride)
+         dilation = to_2tuple(dilation)
+
+         if isinstance(padding, str):
+             self.adap_padding = AdaptivePadding(
+                 kernel_size=kernel_size,
+                 stride=stride,
+                 dilation=dilation,
+                 padding=padding)
+             # disable the padding of unfold
+             padding = 0
+         else:
+             self.adap_padding = None
+
+         padding = to_2tuple(padding)
+         self.sampler = nn.Unfold(
+             kernel_size=kernel_size,
+             dilation=dilation,
+             padding=padding,
+             stride=stride)
+
+         sample_dim = kernel_size[0] * kernel_size[1] * in_channels
+
+         if norm_cfg is not None:
+             self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
+         else:
+             self.norm = None
+
+         self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)
+
+     def forward(self, x, input_size):
+         """
+         Args:
+             x (Tensor): Has shape (B, H*W, C_in).
+             input_size (tuple[int]): The spatial shape of x, arrange as (H, W).
+                 Default: None.
+
+         Returns:
+             tuple: Contains merged results and its spatial shape.
+
+             - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
+             - out_size (tuple[int]): Spatial shape of x, arrange as
+                 (Merged_H, Merged_W).
+         """
+         B, L, C = x.shape
+         assert isinstance(input_size, Sequence), f'Expect ' \
+             f'input_size is ' \
+             f'`Sequence` ' \
+             f'but get {input_size}'
+
+         H, W = input_size
+         assert L == H * W, 'input feature has wrong size'
+
+         x = x.view(B, H, W, C).permute([0, 3, 1, 2])  # B, C, H, W
+         # Use nn.Unfold to merge patch. About 25% faster than original method,
+         # but need to modify pretrained model for compatibility
+
+         if self.adap_padding:
+             x = self.adap_padding(x)
+             H, W = x.shape[-2:]
+
+         x = self.sampler(x)
+         # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2)
+
+         out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] *
+                  (self.sampler.kernel_size[0] - 1) -
+                  1) // self.sampler.stride[0] + 1
+         out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] *
+                  (self.sampler.kernel_size[1] - 1) -
+                  1) // self.sampler.stride[1] + 1
+
+         output_size = (out_h, out_w)
+         x = x.transpose(1, 2)  # B, H/2*W/2, 4*C
+         x = self.norm(x) if self.norm else x
+         x = self.reduction(x)
+         return x, output_size
build/lib/segformer_plusplus/utils/imagenet_weights.py ADDED
@@ -0,0 +1,8 @@
+ imagenet_weights = {
+     'b0': 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segformer/mit_b0_20220624-7e0fe6dd.pth',
+     'b1': 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segformer/mit_b1_20220624-02e5a6a1.pth',
+     'b2': 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segformer/mit_b2_20220624-66e8bf70.pth',
+     'b3': 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segformer/mit_b3_20220624-13b1141c.pth',
+     'b4': 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segformer/mit_b4_20220624-d588d980.pth',
+     'b5': 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segformer/mit_b5_20220624-658746d9.pth'
+ }
build/lib/segformer_plusplus/utils/registry.py ADDED
@@ -0,0 +1,6 @@
+ from mmengine import Registry
+
+ MODELS = Registry(
+     'models',
+     locations=['segformer_plusplus.model.backbone', 'segformer_plusplus.model.head']
+ )
build/lib/segformer_plusplus/utils/shape_convert.py ADDED
@@ -0,0 +1,107 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ def nlc_to_nchw(x, hw_shape):
+     """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor.
+
+     Args:
+         x (Tensor): The input tensor of shape [N, L, C] before conversion.
+         hw_shape (Sequence[int]): The height and width of output feature map.
+
+     Returns:
+         Tensor: The output tensor of shape [N, C, H, W] after conversion.
+     """
+     H, W = hw_shape
+     assert len(x.shape) == 3
+     B, L, C = x.shape
+     assert L == H * W, 'The seq_len doesn\'t match H, W'
+     return x.transpose(1, 2).reshape(B, C, H, W)
+
+
+ def nchw_to_nlc(x):
+     """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor.
+
+     Args:
+         x (Tensor): The input tensor of shape [N, C, H, W] before conversion.
+
+     Returns:
+         Tensor: The output tensor of shape [N, L, C] after conversion.
+     """
+     assert len(x.shape) == 4
+     return x.flatten(2).transpose(1, 2).contiguous()
+
+
+ def nchw2nlc2nchw(module, x, contiguous=False, **kwargs):
+     """Flatten [N, C, H, W] shape tensor `x` to [N, L, C] shape tensor, use the
+     reshaped tensor as the input of `module`, and then convert the [N, L, C]
+     output of `module` back to [N, C, H, W].
+
+     Args:
+         module (Callable): A callable object that takes a tensor
+             with shape [N, L, C] as input.
+         x (Tensor): The input tensor of shape [N, C, H, W].
+         contiguous (Bool): Whether to make the tensor contiguous
+             after each shape transform.
+
+     Returns:
+         Tensor: The output tensor of shape [N, C, H, W].
+
+     Example:
+         >>> import torch
+         >>> import torch.nn as nn
+         >>> norm = nn.LayerNorm(4)
+         >>> feature_map = torch.rand(4, 4, 5, 5)
+         >>> output = nchw2nlc2nchw(norm, feature_map)
+     """
+     B, C, H, W = x.shape
+     if not contiguous:
+         x = x.flatten(2).transpose(1, 2)
+         x = module(x, **kwargs)
+         x = x.transpose(1, 2).reshape(B, C, H, W)
+     else:
+         x = x.flatten(2).transpose(1, 2).contiguous()
+         x = module(x, **kwargs)
+         x = x.transpose(1, 2).reshape(B, C, H, W).contiguous()
+     return x
+
+
+ def nlc2nchw2nlc(module, x, hw_shape, contiguous=False, **kwargs):
+     """Convert [N, L, C] shape tensor `x` to [N, C, H, W] shape tensor, use the
+     reshaped tensor as the input of `module`, and convert the [N, C, H, W]
+     output of `module` back to [N, L, C].
+
+     Args:
+         module (Callable): A callable object that takes a tensor
+             with shape [N, C, H, W] as input.
+         x (Tensor): The input tensor of shape [N, L, C].
+         hw_shape (Sequence[int]): The height and width of the
+             feature map with shape [N, C, H, W].
+         contiguous (Bool): Whether to make the tensor contiguous
+             after each shape transform.
+
+     Returns:
+         Tensor: The output tensor of shape [N, L, C].
+
+     Example:
+         >>> import torch
+         >>> import torch.nn as nn
+         >>> conv = nn.Conv2d(16, 16, 3, 1, 1)
+         >>> feature_map = torch.rand(4, 25, 16)
+         >>> output = nlc2nchw2nlc(conv, feature_map, (5, 5))
+     """
+     H, W = hw_shape
+     assert len(x.shape) == 3
+     B, L, C = x.shape
+     assert L == H * W, 'The seq_len doesn\'t match H, W'
+     if not contiguous:
+         x = x.transpose(1, 2).reshape(B, C, H, W)
+         x = module(x, **kwargs)
+         x = x.flatten(2).transpose(1, 2)
+     else:
+         x = x.transpose(1, 2).reshape(B, C, H, W).contiguous()
+         x = module(x, **kwargs)
+         x = x.flatten(2).transpose(1, 2).contiguous()
+     return x
build/lib/segformer_plusplus/utils/tome_presets.py ADDED
@@ -0,0 +1,20 @@
+ tome_presets = {
+     'bsm_hq': [
+         dict(q_mode=None, kv_mode='bsm', kv_r=0.6, kv_sx=2, kv_sy=2),
+         dict(q_mode=None, kv_mode='bsm', kv_r=0.6, kv_sx=2, kv_sy=2),
+         dict(q_mode='bsm', kv_mode=None, q_r=0.8, q_sx=4, q_sy=4),
+         dict(q_mode='bsm', kv_mode=None, q_r=0.8, q_sx=4, q_sy=4)
+     ],
+     'bsm_fast': [
+         dict(q_mode=None, kv_mode='bsm_r2D', kv_r=0.9, kv_sx=4, kv_sy=4),
+         dict(q_mode=None, kv_mode='bsm_r2D', kv_r=0.9, kv_sx=4, kv_sy=4),
+         dict(q_mode='bsm_r2D', kv_mode=None, q_r=0.9, q_sx=4, q_sy=4),
+         dict(q_mode='bsm_r2D', kv_mode=None, q_r=0.9, q_sx=4, q_sy=4)
+     ],
+     'n2d_2x2': [
+         dict(q_mode='neighbor_2D', kv_mode=None, q_s=(2, 2)),
+         dict(q_mode='neighbor_2D', kv_mode=None, q_s=(2, 2)),
+         dict(q_mode='neighbor_2D', kv_mode=None, q_s=(2, 2)),
+         dict(q_mode='neighbor_2D', kv_mode=None, q_s=(2, 2))
+     ]
+ }
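These presets are per-stage tome_cfg dictionaries, one entry per MiT stage. A custom strategy can presumably be supplied in the same shape; note that the checks in model/backbone/mit.py use the mode strings 'n2d' and 'bsm', while the presets above use 'neighbor_2D' and 'bsm_r2D', so any mapping between the two would have to happen elsewhere in the package. A purely hypothetical example:

# Hypothetical custom strategy: 2x2 neighbor merging of queries in the first two stages, rest untouched.
custom_strategy = [
    dict(q_mode='n2d', kv_mode=None, q_s=(2, 2)),
    dict(q_mode='n2d', kv_mode=None, q_s=(2, 2)),
    dict(),
    dict(),
]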
build/lib/segformer_plusplus/utils/wrappers.py ADDED
@@ -0,0 +1,51 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import warnings
+
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ def resize(input,
+            size=None,
+            scale_factor=None,
+            mode='nearest',
+            align_corners=None,
+            warning=True):
+     if warning:
+         if size is not None and align_corners:
+             input_h, input_w = tuple(int(x) for x in input.shape[2:])
+             output_h, output_w = tuple(int(x) for x in size)
+             if output_h > input_h or output_w > output_h:
+                 if ((output_h > 1 and output_w > 1 and input_h > 1
+                      and input_w > 1) and (output_h - 1) % (input_h - 1)
+                         and (output_w - 1) % (input_w - 1)):
+                     warnings.warn(
+                         f'When align_corners={align_corners}, '
+                         'the output would more aligned if '
+                         f'input size {(input_h, input_w)} is `x+1` and '
+                         f'out size {(output_h, output_w)} is `nx+1`')
+     return F.interpolate(input, size, scale_factor, mode, align_corners)
+
+
+ class Upsample(nn.Module):
+
+     def __init__(self,
+                  size=None,
+                  scale_factor=None,
+                  mode='nearest',
+                  align_corners=None):
+         super().__init__()
+         self.size = size
+         if isinstance(scale_factor, tuple):
+             self.scale_factor = tuple(float(factor) for factor in scale_factor)
+         else:
+             self.scale_factor = float(scale_factor) if scale_factor else None
+         self.mode = mode
+         self.align_corners = align_corners
+
+     def forward(self, x):
+         if not self.size:
+             size = [int(t * self.scale_factor) for t in x.shape[-2:]]
+         else:
+             size = self.size
+         return resize(x, size, None, self.mode, self.align_corners)
cityscapes_prediction_output_reference.txt ADDED
The diff for this file is too large to render.
segformer_plusplus.egg-info/PKG-INFO ADDED
@@ -0,0 +1,11 @@
+ Metadata-Version: 2.1
+ Name: segformer-plusplus
+ Version: 0.2
+ Summary: Segformer++: Efficient Token-Merging Strategies for High-Resolution Semantic Segmentation
+ Home-page: UNKNOWN
+ Author: Marco Kantonis
+ License: MIT
+ Platform: UNKNOWN
+
+ https://arxiv.org/abs/2405.14467
+
segformer_plusplus.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,29 @@
+ setup.py
+ segformer_plusplus/__init__.py
+ segformer_plusplus/build_model.py
+ segformer_plusplus/random_benchmark.py
+ segformer_plusplus.egg-info/PKG-INFO
+ segformer_plusplus.egg-info/SOURCES.txt
+ segformer_plusplus.egg-info/dependency_links.txt
+ segformer_plusplus.egg-info/requires.txt
+ segformer_plusplus.egg-info/top_level.txt
+ segformer_plusplus/configs/__init__.py
+ segformer_plusplus/configs/segformer_mit_b0.py
+ segformer_plusplus/configs/segformer_mit_b1.py
+ segformer_plusplus/configs/segformer_mit_b2.py
+ segformer_plusplus/configs/segformer_mit_b3.py
+ segformer_plusplus/configs/segformer_mit_b4.py
+ segformer_plusplus/configs/segformer_mit_b5.py
+ segformer_plusplus/model/__init__.py
+ segformer_plusplus/model/backbone/__init__.py
+ segformer_plusplus/model/backbone/mit.py
+ segformer_plusplus/model/head/__init__.py
+ segformer_plusplus/model/head/segformer_head.py
+ segformer_plusplus/utils/__init__.py
+ segformer_plusplus/utils/benchmark.py
+ segformer_plusplus/utils/embed.py
+ segformer_plusplus/utils/imagenet_weights.py
+ segformer_plusplus/utils/registry.py
+ segformer_plusplus/utils/shape_convert.py
+ segformer_plusplus/utils/tome_presets.py
+ segformer_plusplus/utils/wrappers.py
segformer_plusplus.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
+
segformer_plusplus.egg-info/requires.txt ADDED
@@ -0,0 +1,2 @@
+ tomesd
+ torch>=2.0.1
segformer_plusplus.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
+ segformer_plusplus
segformer_plusplus/Registry/default_scope.py ADDED
@@ -0,0 +1,95 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import copy
3
+ import time
4
+ from contextlib import contextmanager
5
+ from typing import Generator, Optional
6
+
7
+ from ..utils.manager import ManagerMixin, _accquire_lock, _release_lock
8
+
9
+
10
+ class DefaultScope(ManagerMixin):
11
+ """Scope of current task used to reset the current registry, which can be
12
+ accessed globally.
13
+
14
+ Consider the case of resetting the current ``Registry`` by
15
+ ``default_scope`` in the internal module which cannot access runner
16
+ directly, it is difficult to get the ``default_scope`` defined in
17
+ ``Runner``. However, if ``Runner`` created ``DefaultScope`` instance
18
+ by given ``default_scope``, the internal module can get
19
+ ``default_scope`` by ``DefaultScope.get_current_instance`` everywhere.
20
+
21
+ Args:
22
+ name (str): Name of default scope for global access.
23
+ scope_name (str): Scope of current task.
24
+
25
+ Examples:
26
+ >>> from mmengine.model import MODELS
27
+ >>> # Define default scope in runner.
28
+ >>> DefaultScope.get_instance('task', scope_name='mmdet')
29
+ >>> # Get default scope globally.
30
+ >>> scope_name = DefaultScope.get_instance('task').scope_name
31
+ """
32
+
33
+ def __init__(self, name: str, scope_name: str):
34
+ super().__init__(name)
35
+ assert isinstance(
36
+ scope_name,
37
+ str), (f'scope_name should be a string, but got {scope_name}')
38
+ self._scope_name = scope_name
39
+
40
+ @property
41
+ def scope_name(self) -> str:
42
+ """
43
+ Returns:
44
+ str: Get current scope.
45
+ """
46
+ return self._scope_name
47
+
48
+ @classmethod
49
+ def get_current_instance(cls) -> Optional['DefaultScope']:
50
+ """Get latest created default scope.
51
+
52
+ Since default_scope is an optional argument for ``Registry.build``.
53
+ ``get_current_instance`` should return ``None`` if there is no
54
+ ``DefaultScope`` created.
55
+
56
+ Examples:
57
+ >>> default_scope = DefaultScope.get_current_instance()
58
+ >>> # There is no `DefaultScope` created yet,
59
+ >>> # `get_current_instance` return `None`.
60
+ >>> default_scope = DefaultScope.get_instance(
61
+ >>> 'instance_name', scope_name='mmengine')
62
+ >>> default_scope.scope_name
63
+ mmengine
64
+ >>> default_scope = DefaultScope.get_current_instance()
65
+ >>> default_scope.scope_name
66
+ mmengine
67
+
68
+ Returns:
69
+ Optional[DefaultScope]: Return None If there has not been
70
+ ``DefaultScope`` instance created yet, otherwise return the
71
+ latest created DefaultScope instance.
72
+ """
73
+ _accquire_lock()
74
+ if cls._instance_dict:
75
+ instance = super().get_current_instance()
76
+ else:
77
+ instance = None
78
+ _release_lock()
79
+ return instance
80
+
81
+ @classmethod
82
+ @contextmanager
83
+ def overwrite_default_scope(cls, scope_name: Optional[str]) -> Generator:
84
+ """Overwrite the current default scope with `scope_name`"""
85
+ if scope_name is None:
86
+ yield
87
+ else:
88
+ tmp = copy.deepcopy(cls._instance_dict)
89
+ # To avoid create an instance with the same name.
90
+ time.sleep(1e-6)
91
+ cls.get_instance(f'overwrite-{time.time()}', scope_name=scope_name)
92
+ try:
93
+ yield
94
+ finally:
95
+ cls._instance_dict = tmp
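A minimal sketch of how DefaultScope is typically used, assuming ManagerMixin.get_instance/get_current_instance behave as in mmengine (that class is imported above but not shown in this diff):

    # create the global scope once, then read it anywhere
    DefaultScope.get_instance('task', scope_name='mmseg')
    print(DefaultScope.get_current_instance().scope_name)      # mmseg
    # temporarily switch scopes; the previous instances are restored afterwards
    with DefaultScope.overwrite_default_scope('mmdet'):
        print(DefaultScope.get_current_instance().scope_name)  # mmdet
    print(DefaultScope.get_current_instance().scope_name)      # mmseg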
segformer_plusplus/Registry/registry.py ADDED
@@ -0,0 +1,735 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import inspect
3
+ import sys
4
+ import types
5
+ from collections import abc
6
+ from collections.abc import Callable
7
+ from contextlib import contextmanager
8
+ from importlib import import_module
9
+ from typing import Any, Dict, Generator, List, Optional, Tuple, Type, Union
10
+ from rich.console import Console
11
+ from rich.table import Table
12
+
13
+ from .default_scope import DefaultScope
14
+
15
+
16
+ MODULE2PACKAGE = {
17
+ 'mmcls': 'mmcls',
18
+ 'mmdet': 'mmdet',
19
+ 'mmdet3d': 'mmdet3d',
20
+ 'mmseg': 'mmsegmentation',
21
+ 'mmaction': 'mmaction2',
22
+ 'mmtrack': 'mmtrack',
23
+ 'mmpose': 'mmpose',
24
+ 'mmedit': 'mmedit',
25
+ 'mmocr': 'mmocr',
26
+ 'mmgen': 'mmgen',
27
+ 'mmfewshot': 'mmfewshot',
28
+ 'mmrazor': 'mmrazor',
29
+ 'mmflow': 'mmflow',
30
+ 'mmhuman3d': 'mmhuman3d',
31
+ 'mmrotate': 'mmrotate',
32
+ 'mmselfsup': 'mmselfsup',
33
+ 'mmyolo': 'mmyolo',
34
+ 'mmpretrain': 'mmpretrain',
35
+ 'mmagic': 'mmagic',
36
+ }
37
+
38
+ class Registry:
39
+ """A registry to map strings to classes or functions.
40
+
41
+ Registered object could be built from registry. Meanwhile, registered
42
+ functions could be called from registry.
43
+
44
+ Args:
45
+ name (str): Registry name.
46
+ build_func (callable, optional): A function to construct instance
47
+ from Registry. :func:`build_from_cfg` is used if neither ``parent``
48
+ or ``build_func`` is specified. If ``parent`` is specified and
49
+ ``build_func`` is not given, ``build_func`` will be inherited
50
+ from ``parent``. Defaults to None.
51
+ parent (:obj:`Registry`, optional): Parent registry. The class
52
+ registered in children registry could be built from parent.
53
+ Defaults to None.
54
+ scope (str, optional): The scope of registry. It is the key to search
55
+ for children registry. If not specified, scope will be the name of
56
+ the package where class is defined, e.g. mmdet, mmcls, mmseg.
57
+ Defaults to None.
58
+ locations (list): The locations to import the modules registered
59
+ in this registry. Defaults to [].
60
+ New in version 0.4.0.
61
+
62
+ Examples:
63
+ >>> # define a registry
64
+ >>> MODELS = Registry('models')
65
+ >>> # registry the `ResNet` to `MODELS`
66
+ >>> @MODELS.register_module()
67
+ >>> class ResNet:
68
+ >>> pass
69
+ >>> # build model from `MODELS`
70
+ >>> resnet = MODELS.build(dict(type='ResNet'))
71
+ >>> @MODELS.register_module()
72
+ >>> def resnet50():
73
+ >>> pass
74
+ >>> resnet = MODELS.build(dict(type='resnet50'))
75
+
76
+ >>> # hierarchical registry
77
+ >>> DETECTORS = Registry('detectors', parent=MODELS, scope='det')
78
+ >>> @DETECTORS.register_module()
79
+ >>> class FasterRCNN:
80
+ >>> pass
81
+ >>> fasterrcnn = DETECTORS.build(dict(type='FasterRCNN'))
82
+
83
+ >>> # add locations to enable auto import
84
+ >>> DETECTORS = Registry('detectors', parent=MODELS,
85
+ >>> scope='det', locations=['det.models.detectors'])
86
+ >>> # define this class in 'det.models.detectors'
87
+ >>> @DETECTORS.register_module()
88
+ >>> class MaskRCNN:
89
+ >>> pass
90
+ >>> # The registry will auto import det.models.detectors.MaskRCNN
91
+ >>> fasterrcnn = DETECTORS.build(dict(type='det.MaskRCNN'))
92
+
93
+ More advanced usages can be found at
94
+ https://mmengine.readthedocs.io/en/latest/advanced_tutorials/registry.html.
95
+ """
96
+
97
+ def __init__(self,
98
+ name: str,
99
+ build_func: Optional[Callable] = None,
100
+ parent: Optional['Registry'] = None,
101
+ scope: Optional[str] = None,
102
+ locations: List = []):
103
+ self._name = name
104
+ self._module_dict: Dict[str, Type] = dict()
105
+ self._children: Dict[str, 'Registry'] = dict()
106
+ self._locations = locations
107
+ self._imported = False
108
+
109
+ if scope is not None:
110
+ assert isinstance(scope, str)
111
+ self._scope = scope
112
+ else:
113
+ self._scope = self.infer_scope()
114
+
115
+ # See https://mypy.readthedocs.io/en/stable/common_issues.html#
116
+ # variables-vs-type-aliases for the use
117
+ self.parent: Optional['Registry']
118
+ if parent is not None:
119
+ assert isinstance(parent, Registry)
120
+ parent._add_child(self)
121
+ self.parent = parent
122
+ else:
123
+ self.parent = None
124
+
125
+ # self.build_func will be set with the following priority:
126
+ # 1. build_func
127
+ # 2. parent.build_func
128
+ # 3. build_from_cfg
129
+ self.build_func: Callable
130
+ if build_func is None:
131
+ if self.parent is not None:
132
+ self.build_func = self.parent.build_func
133
+ else:
134
+ from ..utils.build_functions import build_from_cfg
135
+ self.build_func = build_from_cfg
136
+ else:
137
+ self.build_func = build_func
138
+
139
+ def __len__(self):
140
+ return len(self._module_dict)
141
+
142
+ def __contains__(self, key):
143
+ return self.get(key) is not None
144
+
145
+ def __repr__(self):
146
+ table = Table(title=f'Registry of {self._name}')
147
+ table.add_column('Names', justify='left', style='cyan')
148
+ table.add_column('Objects', justify='left', style='green')
149
+
150
+ for name, obj in sorted(self._module_dict.items()):
151
+ table.add_row(name, str(obj))
152
+
153
+ console = Console()
154
+ with console.capture() as capture:
155
+ console.print(table, end='')
156
+
157
+ return capture.get()
158
+
159
+ @staticmethod
160
+ def infer_scope() -> str:
161
+ """Infer the scope of registry.
162
+
163
+ The name of the package where registry is defined will be returned.
164
+
165
+ Returns:
166
+ str: The inferred scope name.
167
+
168
+ Examples:
169
+ >>> # in mmdet/models/backbone/resnet.py
170
+ >>> MODELS = Registry('models')
171
+ >>> @MODELS.register_module()
172
+ >>> class ResNet:
173
+ >>> pass
174
+ >>> # The scope of ``ResNet`` will be ``mmdet``.
175
+ """
176
+
177
+ # `sys._getframe` returns the frame object that many calls below the
178
+ # top of the stack. The call stack for `infer_scope` can be listed as
179
+ # follow:
180
+ # frame-0: `infer_scope` itself
181
+ # frame-1: `__init__` of `Registry` which calls the `infer_scope`
182
+ # frame-2: Where the `Registry(...)` is called
183
+ module = inspect.getmodule(sys._getframe(2))
184
+ if module is not None:
185
+ filename = module.__name__
186
+ split_filename = filename.split('.')
187
+ scope = split_filename[0]
188
+ else:
189
+ # use "mmengine" to handle some cases which can not infer the scope
190
+ # like initializing Registry in interactive mode
191
+ scope = 'mmengine'
192
+ return scope
193
+
194
+ @staticmethod
195
+ def split_scope_key(key: str) -> Tuple[Optional[str], str]:
196
+ """Split scope and key.
197
+
198
+ The first scope will be split from key.
199
+
200
+ Return:
201
+ tuple[str | None, str]: The former element is the first scope of
202
+ the key, which can be ``None``. The latter is the remaining key.
203
+
204
+ Examples:
205
+ >>> Registry.split_scope_key('mmdet.ResNet')
206
+ 'mmdet', 'ResNet'
207
+ >>> Registry.split_scope_key('ResNet')
208
+ None, 'ResNet'
209
+ """
210
+ split_index = key.find('.')
211
+ if split_index != -1:
212
+ return key[:split_index], key[split_index + 1:]
213
+ else:
214
+ return None, key
215
+
216
+ @property
217
+ def name(self):
218
+ return self._name
219
+
220
+ @property
221
+ def scope(self):
222
+ return self._scope
223
+
224
+ @property
225
+ def module_dict(self):
226
+ return self._module_dict
227
+
228
+ @property
229
+ def children(self):
230
+ return self._children
231
+
232
+ @property
233
+ def root(self):
234
+ return self._get_root_registry()
235
+
236
+ @contextmanager
237
+ def switch_scope_and_registry(self, scope: Optional[str]) -> Generator:
238
+ """Temporarily switch default scope to the target scope, and get the
239
+ corresponding registry.
240
+
241
+ If the registry of the corresponding scope exists, yield the
242
+ registry, otherwise yield the current itself.
243
+
244
+ Args:
245
+ scope (str, optional): The target scope.
246
+
247
+ Examples:
248
+ >>> from mmengine.registry import Registry, DefaultScope, MODELS
249
+ >>> import time
250
+ >>> # External Registry
251
+ >>> MMDET_MODELS = Registry('mmdet_model', scope='mmdet',
252
+ >>> parent=MODELS)
253
+ >>> MMCLS_MODELS = Registry('mmcls_model', scope='mmcls',
254
+ >>> parent=MODELS)
255
+ >>> # Local Registry
256
+ >>> CUSTOM_MODELS = Registry('custom_model', scope='custom',
257
+ >>> parent=MODELS)
258
+ >>>
259
+ >>> # Initiate DefaultScope
260
+ >>> DefaultScope.get_instance(f'scope_{time.time()}',
261
+ >>> scope_name='custom')
262
+ >>> # Check default scope
263
+ >>> DefaultScope.get_current_instance().scope_name
264
+ custom
265
+ >>> # Switch to mmcls scope and get `MMCLS_MODELS` registry.
266
+ >>> with CUSTOM_MODELS.switch_scope_and_registry(scope='mmcls') as registry:
267
+ >>> DefaultScope.get_current_instance().scope_name
268
+ mmcls
269
+ >>> registry.scope
270
+ mmcls
271
+ >>> # Nested switch scope
272
+ >>> with CUSTOM_MODELS.switch_scope_and_registry(scope='mmdet') as mmdet_registry:
273
+ >>> DefaultScope.get_current_instance().scope_name
274
+ mmdet
275
+ >>> mmdet_registry.scope
276
+ mmdet
277
+ >>> with CUSTOM_MODELS.switch_scope_and_registry(scope='mmcls') as mmcls_registry:
278
+ >>> DefaultScope.get_current_instance().scope_name
279
+ mmcls
280
+ >>> mmcls_registry.scope
281
+ mmcls
282
+ >>>
283
+ >>> # Check switch back to original scope.
284
+ >>> DefaultScope.get_current_instance().scope_name
285
+ custom
286
+ """ # noqa: E501
287
+
288
+ # Switch to the given scope temporarily. If the corresponding registry
289
+ # can be found in root registry, return the registry under the scope,
290
+ # otherwise return the registry itself.
291
+ with DefaultScope.overwrite_default_scope(scope):
292
+ # Get the global default scope
293
+ default_scope = DefaultScope.get_current_instance()
294
+ # Get registry by scope
295
+ if default_scope is not None:
296
+ scope_name = default_scope.scope_name
297
+ try:
298
+ import_module(f'{scope_name}.registry')
299
+ except (ImportError, AttributeError, ModuleNotFoundError):
300
+ if scope in MODULE2PACKAGE:
301
+ print(
302
+ f'{scope} is not installed and its '
303
+ 'modules will not be registered. If you '
304
+ 'want to use modules defined in '
305
+ f'{scope}, Please install {scope} by '
306
+ f'`pip install {MODULE2PACKAGE[scope]}`.')
307
+ else:
308
+ print(
309
+ f'Failed to import `{scope}.registry` '
310
+ f'make sure the registry.py exists in `{scope}` '
311
+ 'package.',)
312
+ root = self._get_root_registry()
313
+ registry = root._search_child(scope_name)
314
+ if registry is None:
315
+ # if `default_scope` can not be found, fallback to argument
316
+ # `registry`
317
+ print(
318
+ f'Failed to search registry with scope "{scope_name}" '
319
+ f'in the "{root.name}" registry tree. '
320
+ f'As a workaround, the current "{self.name}" registry '
321
+ f'in "{self.scope}" is used to build instance. This '
322
+ 'may cause unexpected failure when running the built '
323
+ f'modules. Please check whether "{scope_name}" is a '
324
+ 'correct scope, or whether the registry is '
325
+ 'initialized.',)
326
+ registry = self
327
+ # If there is no built default scope, just return current registry.
328
+ else:
329
+ registry = self
330
+ yield registry
331
+
332
+ def _get_root_registry(self) -> 'Registry':
333
+ """Return the root registry."""
334
+ root = self
335
+ while root.parent is not None:
336
+ root = root.parent
337
+ return root
338
+
339
+ def import_from_location(self) -> None:
340
+ """Import modules from the pre-defined locations in self._location."""
341
+ if not self._imported:
342
+ # avoid BC breaking
343
+ if len(self._locations) == 0 and self.scope in MODULE2PACKAGE:
344
+ print(
345
+ f'The "{self.name}" registry in {self.scope} did not '
346
+ 'set import location. Fallback to call '
347
+ f'`{self.scope}.utils.register_all_modules` '
348
+ 'instead.',)
349
+ try:
350
+ module = import_module(f'{self.scope}.utils')
351
+ except (ImportError, AttributeError, ModuleNotFoundError):
352
+ if self.scope in MODULE2PACKAGE:
353
+ print(
354
+ f'{self.scope} is not installed and its '
355
+ 'modules will not be registered. If you '
356
+ 'want to use modules defined in '
357
+ f'{self.scope}, Please install {self.scope} by '
358
+ f'`pip install {MODULE2PACKAGE[self.scope]}`.',)
359
+ else:
360
+ print(
361
+ f'Failed to import {self.scope} and register '
362
+ 'its modules, please make sure you '
363
+ 'have registered the module manually.',)
364
+ else:
365
+ # The import errors triggered during the registration
366
+ # may be more complex, here just throwing
367
+ # the error to avoid causing more implicit registry errors
368
+ # like `xxx`` not found in `yyy` registry.
369
+ module.register_all_modules(False) # type: ignore
370
+
371
+ for loc in self._locations:
372
+ import_module(loc)
373
+ print(
374
+ f"Modules of {self.scope}'s {self.name} registry have "
375
+ f'been automatically imported from {loc}',)
376
+ self._imported = True
377
+
378
+ def get(self, key: str) -> Optional[Type]:
379
+ """Get the registry record.
380
+
381
+ If `key`` represents the whole object name with its module
382
+ information, for example, `mmengine.model.BaseModel`, ``get``
383
+ will directly return the class object :class:`BaseModel`.
384
+
385
+ Otherwise, it will first parse ``key`` and check whether it
386
+ contains a scope name. The logic to search for ``key``:
387
+
388
+ - ``key`` does not contain a scope name, i.e., it is purely a module
389
+ name like "ResNet": :meth:`get` will search for ``ResNet`` from the
390
+ current registry to its parent or ancestors until finding it.
391
+
392
+ - ``key`` contains a scope name and it is equal to the scope of the
393
+ current registry (e.g., "mmcls"), e.g., "mmcls.ResNet": :meth:`get`
394
+ will only search for ``ResNet`` in the current registry.
395
+
396
+ - ``key`` contains a scope name and it is not equal to the scope of
397
+ the current registry (e.g., "mmdet"), e.g., "mmcls.FCNet": If the
398
+ scope exists in its children, :meth:`get` will get "FCNet" from
399
+ them. If not, :meth:`get` will first get the root registry and root
400
+ registry call its own :meth:`get` method.
401
+
402
+ Args:
403
+ key (str): Name of the registered item, e.g., the class name in
404
+ string format.
405
+
406
+ Returns:
407
+ Type or None: Return the corresponding class if ``key`` exists,
408
+ otherwise return None.
409
+
410
+ Examples:
411
+ >>> # define a registry
412
+ >>> MODELS = Registry('models')
413
+ >>> # register `ResNet` to `MODELS`
414
+ >>> @MODELS.register_module()
415
+ >>> class ResNet:
416
+ >>> pass
417
+ >>> resnet_cls = MODELS.get('ResNet')
418
+
419
+ >>> # hierarchical registry
420
+ >>> DETECTORS = Registry('detector', parent=MODELS, scope='det')
421
+ >>> # `ResNet` does not exist in `DETECTORS` but `get` method
422
+ >>> # will try to search from its parents or ancestors
423
+ >>> resnet_cls = DETECTORS.get('ResNet')
424
+ >>> CLASSIFIER = Registry('classifier', parent=MODELS, scope='cls')
425
+ >>> @CLASSIFIER.register_module()
426
+ >>> class MobileNet:
427
+ >>> pass
428
+ >>> # `get` from its sibling registries
429
+ >>> mobilenet_cls = DETECTORS.get('cls.MobileNet')
430
+ """
431
+
432
+ if not isinstance(key, str):
433
+ raise TypeError(
434
+ 'The key argument of `Registry.get` must be a str, '
435
+ f'got {type(key)}')
436
+
437
+ scope, real_key = self.split_scope_key(key)
438
+ obj_cls = None
439
+ registry_name = self.name
440
+ scope_name = self.scope
441
+
442
+ # lazy import the modules to register them into the registry
443
+ self.import_from_location()
444
+
445
+ if scope is None or scope == self._scope:
446
+ # get from self
447
+ if real_key in self._module_dict:
448
+ obj_cls = self._module_dict[real_key]
449
+ elif scope is None:
450
+ # try to get the target from its parent or ancestors
451
+ parent = self.parent
452
+ while parent is not None:
453
+ if real_key in parent._module_dict:
454
+ obj_cls = parent._module_dict[real_key]
455
+ registry_name = parent.name
456
+ scope_name = parent.scope
457
+ break
458
+ parent = parent.parent
459
+ else:
460
+ # import the registry to add the nodes into the registry tree
461
+ try:
462
+ import_module(f'{scope}.registry')
463
+ print(
464
+ f'Registry node of {scope} has been automatically '
465
+ 'imported.',)
466
+ except (ImportError, AttributeError, ModuleNotFoundError):
467
+ print(
468
+ f'Cannot auto import {scope}.registry, please check '
469
+ f'whether the package "{scope}" is installed correctly '
470
+ 'or import the registry manually.',)
471
+ # get from self._children
472
+ if scope in self._children:
473
+ obj_cls = self._children[scope].get(real_key)
474
+ registry_name = self._children[scope].name
475
+ scope_name = scope
476
+ else:
477
+ root = self._get_root_registry()
478
+
479
+ if scope != root._scope and scope not in root._children:
480
+ # If not skip directly, `root.get(key)` will recursively
481
+ # call itself until RecursionError is thrown.
482
+ pass
483
+ else:
484
+ obj_cls = root.get(key)
485
+
486
+ if obj_cls is None:
487
+ # Actually, it's strange to implement this `try ... except` to
488
+ # get the object by its name in `Registry.get`. However, If we
489
+ # want to build the model using a configuration like
490
+ # `dict(type='mmengine.model.BaseModel')`, which can
491
+ # be dumped by lazy import config, we need this code snippet
492
+ # for `Registry.get` to work.
493
+ try:
494
+ obj_cls = get_object_from_string(key)
495
+ except Exception:
496
+ raise RuntimeError(f'Failed to get {key}')
497
+
498
+ if obj_cls is not None:
499
+ # For some rare cases (e.g. obj_cls is a partial function), obj_cls
500
+ # doesn't have `__name__`. Use default value to prevent error
501
+ cls_name = getattr(obj_cls, '__name__', str(obj_cls))
502
+ return obj_cls
503
+
504
+ def _search_child(self, scope: str) -> Optional['Registry']:
505
+ """Depth-first search for the corresponding registry in its children.
506
+
507
+ Note that the method only search for the corresponding registry from
508
+ the current registry. Therefore, if we want to search from the root
509
+ registry, :meth:`_get_root_registry` should be called to get the
510
+ root registry first.
511
+
512
+ Args:
513
+ scope (str): The scope name used for searching for its
514
+ corresponding registry.
515
+
516
+ Returns:
517
+ Registry or None: Return the corresponding registry if ``scope``
518
+ exists, otherwise return None.
519
+ """
520
+ if self._scope == scope:
521
+ return self
522
+
523
+ for child in self._children.values():
524
+ registry = child._search_child(scope)
525
+ if registry is not None:
526
+ return registry
527
+
528
+ return None
529
+
530
+ def build(self, cfg: dict, *args, **kwargs) -> Any:
531
+ """Build an instance.
532
+
533
+ Build an instance by calling :attr:`build_func`.
534
+
535
+ Args:
536
+ cfg (dict): Config dict needs to be built.
537
+
538
+ Returns:
539
+ Any: The constructed object.
540
+
541
+ Examples:
542
+ >>> from mmengine import Registry
543
+ >>> MODELS = Registry('models')
544
+ >>> @MODELS.register_module()
545
+ >>> class ResNet:
546
+ >>> def __init__(self, depth, stages=4):
547
+ >>> self.depth = depth
548
+ >>> self.stages = stages
549
+ >>> cfg = dict(type='ResNet', depth=50)
550
+ >>> model = MODELS.build(cfg)
551
+ """
552
+ return self.build_func(cfg, *args, **kwargs, registry=self)
553
+
554
+ def _add_child(self, registry: 'Registry') -> None:
555
+ """Add a child for a registry.
556
+
557
+ Args:
558
+ registry (:obj:`Registry`): The ``registry`` will be added as a
559
+ child of the ``self``.
560
+ """
561
+
562
+ assert isinstance(registry, Registry)
563
+ assert registry.scope is not None
564
+ assert registry.scope not in self.children, \
565
+ f'scope {registry.scope} exists in {self.name} registry'
566
+ self.children[registry.scope] = registry
567
+
568
+ def _register_module(self,
569
+ module: Type,
570
+ module_name: Optional[Union[str, List[str]]] = None,
571
+ force: bool = False) -> None:
572
+ """Register a module.
573
+
574
+ Args:
575
+ module (type): Module to be registered. Typically a class or a
576
+ function, but generally all ``Callable`` are acceptable.
577
+ module_name (str or list of str, optional): The module name to be
578
+ registered. If not specified, the class name will be used.
579
+ Defaults to None.
580
+ force (bool): Whether to override an existing class with the same
581
+ name. Defaults to False.
582
+ """
583
+ if not callable(module):
584
+ raise TypeError(f'module must be Callable, but got {type(module)}')
585
+
586
+ if module_name is None:
587
+ module_name = module.__name__
588
+ if isinstance(module_name, str):
589
+ module_name = [module_name]
590
+ for name in module_name:
591
+ if not force and name in self._module_dict:
592
+ existed_module = self.module_dict[name]
593
+ raise KeyError(f'{name} is already registered in {self.name} '
594
+ f'at {existed_module.__module__}')
595
+ self._module_dict[name] = module
596
+
597
+ def register_module(
598
+ self,
599
+ name: Optional[Union[str, List[str]]] = None,
600
+ force: bool = False,
601
+ module: Optional[Type] = None) -> Union[type, Callable]:
602
+ """Register a module.
603
+
604
+ A record will be added to ``self._module_dict``, whose key is the class
605
+ name or the specified name, and value is the class itself.
606
+ It can be used as a decorator or a normal function.
607
+
608
+ Args:
609
+ name (str or list of str, optional): The module name to be
610
+ registered. If not specified, the class name will be used.
611
+ force (bool): Whether to override an existing class with the same
612
+ name. Defaults to False.
613
+ module (type, optional): Module class or function to be registered.
614
+ Defaults to None.
615
+
616
+ Examples:
617
+ >>> backbones = Registry('backbone')
618
+ >>> # as a decorator
619
+ >>> @backbones.register_module()
620
+ >>> class ResNet:
621
+ >>> pass
622
+ >>> backbones = Registry('backbone')
623
+ >>> @backbones.register_module(name='mnet')
624
+ >>> class MobileNet:
625
+ >>> pass
626
+
627
+ >>> # as a normal function
628
+ >>> class ResNet:
629
+ >>> pass
630
+ >>> backbones.register_module(module=ResNet)
631
+ """
632
+ if not isinstance(force, bool):
633
+ raise TypeError(f'force must be a boolean, but got {type(force)}')
634
+
635
+ # raise the error ahead of time
636
+ if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
637
+ raise TypeError(
638
+ 'name must be None, an instance of str, or a sequence of str, '
639
+ f'but got {type(name)}')
640
+
641
+ # use it as a normal method: x.register_module(module=SomeClass)
642
+ if module is not None:
643
+ self._register_module(module=module, module_name=name, force=force)
644
+ return module
645
+
646
+ # use it as a decorator: @x.register_module()
647
+ def _register(module):
648
+ self._register_module(module=module, module_name=name, force=force)
649
+ return module
650
+
651
+ return _register
652
+
653
+
654
+ def is_seq_of(seq: Any,
655
+ expected_type: Union[Type, tuple],
656
+ seq_type: Optional[Type] = None) -> bool:
657
+ """Check whether it is a sequence of some type.
658
+
659
+ Args:
660
+ seq (Sequence): The sequence to be checked.
661
+ expected_type (type or tuple): Expected type of sequence items.
662
+ seq_type (type, optional): Expected sequence type. Defaults to None.
663
+
664
+ Returns:
665
+ bool: Return True if ``seq`` is valid else False.
666
+
667
+ Examples:
668
+ >>> from mmengine.utils import is_seq_of
669
+ >>> seq = ['a', 'b', 'c']
670
+ >>> is_seq_of(seq, str)
671
+ True
672
+ >>> is_seq_of(seq, int)
673
+ False
674
+ """
675
+ if seq_type is None:
676
+ exp_seq_type = abc.Sequence
677
+ else:
678
+ assert isinstance(seq_type, type)
679
+ exp_seq_type = seq_type
680
+ if not isinstance(seq, exp_seq_type):
681
+ return False
682
+ for item in seq:
683
+ if not isinstance(item, expected_type):
684
+ return False
685
+ return True
686
+
687
+
688
+ def get_object_from_string(obj_name: str):
689
+ """Get object from name.
690
+
691
+ Args:
692
+ obj_name (str): The name of the object.
693
+
694
+ Examples:
695
+ >>> get_object_from_string('torch.optim.sgd.SGD')
696
+ >>> torch.optim.sgd.SGD
697
+ """
698
+ parts = iter(obj_name.split('.'))
699
+ module_name = next(parts)
700
+ # import module
701
+ while True:
702
+ try:
703
+ module = import_module(module_name)
704
+ part = next(parts)
705
+ # mmcv.ops has nms.py and nms function at the same time. So the
706
+ # function will have a higher priority
707
+ obj = getattr(module, part, None)
708
+ if obj is not None and not ismodule(obj):
709
+ break
710
+ module_name = f'{module_name}.{part}'
711
+ except StopIteration:
712
+ # if obj is a module
713
+ return module
714
+ except ImportError:
715
+ return None
716
+
717
+ # get class or attribute from module
718
+ obj = module
719
+ while True:
720
+ try:
721
+ obj = getattr(obj, part)
722
+ part = next(parts)
723
+ except StopIteration:
724
+ return obj
725
+ except AttributeError:
726
+ return None
727
+
728
+ def ismodule(object):
729
+ """Return true if the object is a module.
730
+
731
+ Module objects provide these attributes:
732
+ __cached__ pathname to byte compiled file
733
+ __doc__ documentation string
734
+ __file__ filename (missing for built-in modules)"""
735
+ return isinstance(object, types.ModuleType)
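A minimal sketch of the register/build round trip provided by the Registry above, assuming the package's build_from_cfg instantiates the registered class from a config dict in the usual mmengine fashion:

    # a registry with an explicit scope; TinyHead is a hypothetical example class
    MODELS = Registry('models', scope='demo')

    @MODELS.register_module()
    class TinyHead:
        def __init__(self, channels=19):
            self.channels = channels

    # 'type' selects the registered name, remaining keys become constructor kwargs
    head = MODELS.build(dict(type='TinyHead', channels=8))
    assert isinstance(head, TinyHead) and head.channels == 8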
segformer_plusplus/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .build_model import create_model, create_custom_model
2
+ from .random_benchmark import random_benchmark
3
+
4
+ __all__ = ['create_model', 'create_custom_model', 'random_benchmark']
segformer_plusplus/build_model.py ADDED
@@ -0,0 +1,107 @@
1
+ import os
2
+
3
+ from .utils import MODELS, imagenet_weights
4
+ from .utils import tome_presets
5
+ from .model.base_module import BaseModule
6
+ from .configs.config.config import Config
7
+ from .utils.build_functions import build_model_from_cfg
8
+
9
+
10
+ class SegFormer(BaseModule):
11
+ """
12
+ This class represents a SegFormer model that allows for the application of token merging.
13
+
14
+ Attributes:
15
+ backbone (BaseModule): MixVisionTransformer backbone
16
+ decode_head (BaseModule): SegFormer head
17
+
18
+ """
19
+ def __init__(self, cfg):
20
+ """
21
+ Initialize the SegFormer model.
22
+
23
+ Args:
24
+ cfg (Config): an mmengine Config object, which defines the backbone, head and token merging strategy used.
25
+
26
+ """
27
+ super().__init__()
28
+ self.backbone = build_model_from_cfg(cfg.backbone, registry=MODELS)
29
+ self.decode_head = build_model_from_cfg(cfg.decode_head, registry=MODELS)
30
+
31
+ def forward(self, x):
32
+ """
33
+ Forward pass of the model.
34
+
35
+ Args:
36
+ x (torch.Tensor): input tensor of shape [B, C, H, W]
37
+
38
+ Returns:
39
+ torch.Tensor: output tensor
40
+
41
+ """
42
+ x = self.backbone(x)
43
+ x = self.decode_head(x)
44
+ return x
45
+
46
+
47
+ def create_model(
48
+ backbone: str = 'b0',
49
+ tome_strategy: str = None,
50
+ out_channels: int = 19,
51
+ pretrained: bool = False,
52
+ ):
53
+ """
54
+ Create a SegFormer model using the predefined SegFormer backbones from the MiT series (b0-b5).
55
+
56
+ Args:
57
+ backbone (str): backbone name (e.g. 'b0')
58
+ tome_strategy (str | list(dict)): select a strategy from the presets ('bsm_hq', 'bsm_fast', 'n2d_2x2') or
59
+ define a custom strategy as a list of dictionaries, one per stage, each specifying the token merging
60
+ strategy used in that stage
61
+ out_channels (int): number of output channels (e.g. 19 for the cityscapes semantic segmentation task)
62
+ pretrained: use pretrained (imagenet) weights
63
+
64
+ Returns:
65
+ BaseModule: SegFormer model
66
+
67
+ """
68
+ backbone = backbone.lower()
69
+ assert backbone in [f'b{i}' for i in range(6)]
70
+
71
+ wd = os.path.dirname(os.path.abspath(__file__))
72
+
73
+ cfg = Config.fromfile(os.path.join(wd, 'configs', f'segformer_mit_{backbone}.py'))
74
+
75
+ cfg.decode_head.out_channels = out_channels
76
+
77
+ if tome_strategy is not None:
79
+ if tome_strategy in tome_presets:
80
+ tome_strategy = tome_presets[tome_strategy]
81
+ else:
+ print("Using custom merging strategy.")
+ cfg.backbone.tome_cfg = tome_strategy
81
+
82
+ # load imagenet weights
83
+ if pretrained:
84
+ cfg.backbone.init_cfg = dict(type='Pretrained', checkpoint=imagenet_weights[backbone])
85
+
86
+ return SegFormer(cfg)
87
+
88
+
89
+ def create_custom_model(
90
+ model_cfg: Config,
91
+ tome_strategy: list[dict] = None,
92
+ ):
93
+ """
94
+ Create a SegFormer model with customizable backbone and head.
95
+
96
+ Args:
97
+ model_cfg (Config): backbone name (e.g. 'b0')
98
+ tome_strategy (list(dict)): custom token merging strategy
99
+
100
+ Returns:
101
+ BaseModule: SegFormer model
102
+
103
+ """
104
+ if tome_strategy is not None:
105
+ model_cfg.backbone.tome_cfg = tome_strategy
106
+
107
+ return SegFormer(model_cfg)
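A minimal sketch of building a model with create_model and running a dummy forward pass; 'bsm_fast' is one of the presets named in the docstring above, and the output shape comment is the usual SegFormer-head behaviour rather than something asserted by this diff:

    import torch
    from segformer_plusplus import create_model

    model = create_model(backbone='b1', tome_strategy='bsm_fast', out_channels=19, pretrained=False)
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 512, 512))
    print(logits.shape)  # typically [1, 19, H/4, W/4]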
segformer_plusplus/cityscape_benchmark.py ADDED
@@ -0,0 +1,117 @@
1
+ import torch
2
+ from PIL import Image
3
+ import torchvision.transforms as T
4
+ import os
5
+ from typing import Union, List, Tuple
6
+ import numpy as np
7
+
8
+ from .utils.benchmark import benchmark
9
+
10
+
11
+ # Select the device (GPU if available)
12
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
+ print(f"Using device: {device}")
14
+ if device.type == 'cuda':
15
+ print(f"CUDA Device Name: {torch.cuda.get_device_name(torch.cuda.current_device())}")
16
+
17
+
18
+ def cityscape_benchmark(
19
+ model: torch.nn.Module,
20
+ image_path: str,
21
+ batch_size: Union[int, List[int]] = 1,
22
+ image_size: Union[Tuple[int], List[Tuple[int]]] = (3, 1024, 1024),
23
+ save_output: bool = True,
24
+
25
+ ):
26
+ """
27
+ Calculate the FPS of a given model using an actual Cityscapes image.
28
+
29
+ Args:
30
+ model: instance of a model (e.g. SegFormer)
31
+ image_path: the path to the Cityscapes image
32
+ batch_size: the batch size(s) at which to calculate the FPS (e.g. 1 or [1, 2, 4])
33
+ image_size: the size of the images to use (e.g. (3, 1024, 1024))
34
+ save_output: whether to save the output prediction (default True)
35
+
36
+ Returns:
37
+ the FPS values calculated for all image sizes and batch sizes in the form of a dictionary
38
+ """
39
+
40
+
41
+ if isinstance(batch_size, int):
42
+ batch_size = [batch_size]
43
+ if isinstance(image_size, tuple):
44
+ image_size = [image_size]
45
+
46
+ values = {}
47
+ throughput_values = []
48
+
49
+ model = model.to(device)
50
+ model.eval()
51
+
52
+ assert os.path.exists(image_path), f"Image not found: {image_path}"
53
+ image = Image.open(image_path).convert("RGB")
54
+
55
+ img_tensor = T.ToTensor()(image)
56
+ mean = img_tensor.mean(dim=(1, 2))
57
+ std = img_tensor.std(dim=(1, 2))
58
+ print(f"Calculated Mean: {mean}")
59
+ print(f"Calculated Std: {std}")
60
+
61
+ transform = T.Compose([
62
+ T.Resize((image_size[0][1], image_size[0][2])),
63
+ T.ToTensor(),
64
+ T.Normalize(mean=mean.tolist(),
65
+ std=std.tolist())
66
+ ])
67
+
68
+ img_tensor = transform(image).unsqueeze(0).to(device)
69
+
70
+ for i in image_size:
71
+ # fill with fps for each batch size
72
+ fps = []
73
+ for b in batch_size:
74
+ for _ in range(4):
75
+ # Baseline benchmark
76
+ if i[1] >= 1024:
77
+ r = 16
78
+ else:
79
+ r = 32
80
+ baseline_throughput = benchmark(
81
+ model.to(device),
82
+ device=device,
83
+ verbose=True,
84
+ runs=r,
85
+ batch_size=b,
86
+ input_size=i
87
+ )
88
+ throughput_values.append(baseline_throughput)
89
+ throughput_values = np.asarray(throughput_values)
90
+ throughput = np.around(np.mean(throughput_values), decimals=2)
91
+ print('Im_size:', i, 'Batch_size:', b, 'Mean:', throughput, 'Std:',
92
+ np.around(np.std(throughput_values), decimals=2))
93
+ throughput_values = []
94
+ fps.append({b: throughput})
95
+ values[i] = fps
96
+
97
+ if save_output:
98
+ with torch.no_grad():
99
+ with open("model_output_log.txt", "w") as f:
100
+ f.write("=== Model Input Info ===\n")
101
+ f.write(f"Input tensor:\n{img_tensor}\n")
102
+ f.write(f"Input shape: {img_tensor.shape}\n")
103
+ f.write(f"Input stats: mean = {img_tensor.mean().item()}, std = {img_tensor.std().item()}\n\n")
104
+
105
+ output = model(img_tensor)
106
+
107
+ f.write("=== Raw Model Output ===\n")
108
+ f.write(f"{output}\n\n")
109
+
110
+ pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()
111
+
112
+ # Save the prediction as a text file
113
+ np.savetxt("cityscapes_prediction_output.txt", pred, fmt="%d")
114
+
115
+ print("Prediction saved as cityscapes_prediction_output.txt")
116
+
117
+ return values
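A minimal sketch of calling the benchmark above; the image path is a placeholder for a real Cityscapes frame, and the module path assumes the file added in this commit:

    from segformer_plusplus import create_model
    from segformer_plusplus.cityscape_benchmark import cityscape_benchmark

    model = create_model(backbone='b0', out_channels=19)
    fps = cityscape_benchmark(
        model,
        image_path='/data/cityscapes/leftImg8bit/val/frankfurt/frankfurt_000000_000294_leftImg8bit.png',
        batch_size=[1, 2],
        image_size=(3, 1024, 1024),
        save_output=False,
    )
    print(fps)  # {(3, 1024, 1024): [{1: ...}, {2: ...}]}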
segformer_plusplus/configs/__init__.py ADDED
@@ -0,0 +1 @@
1
+ __all__ = []
segformer_plusplus/configs/config/config.py ADDED
@@ -0,0 +1,1545 @@
1
+ import ast
2
+ import copy
3
+ import os
4
+ import os.path as osp
5
+ import platform
6
+ import shutil
7
+ import sys
8
+ import tempfile
9
+ import types
10
+ import uuid
11
+ import re
12
+ import warnings
13
+ from argparse import ArgumentParser
14
+ from collections import OrderedDict, abc
15
+ from pathlib import Path
16
+ from typing import Any, Optional, Tuple, Union
17
+ from omegaconf import OmegaConf
18
+ import yapf
19
+ from addict import Dict
20
+ from yapf.yapflib.yapf_api import FormatCode
21
+
22
+ from .lazy import LazyAttr, LazyObject
23
+ from .utils import (check_file_exist, get_installed_path, import_modules_from_strings, is_installed, RemoveAssignFromAST,
24
+ ImportTransformer, _gather_abs_import_lazyobj, _get_external_cfg_base_path,
25
+ _get_external_cfg_path, _get_package_and_cfg_path, _is_builtin_module, dump)
26
+
27
+
28
+ BASE_KEY = '_base_'
29
+ DELETE_KEY = '_delete_'
30
+ DEPRECATION_KEY = '_deprecation_'
31
+ RESERVED_KEYS = ['filename', 'text', 'pretty_text', 'env_variables']
32
+
33
+
34
+ def _lazy2string(cfg_dict, dict_type=None):
35
+ if isinstance(cfg_dict, dict):
36
+ dict_type = dict_type or type(cfg_dict)
37
+ return dict_type(
38
+ {k: _lazy2string(v, dict_type)
39
+ for k, v in dict.items(cfg_dict)})
40
+ elif isinstance(cfg_dict, (tuple, list)):
41
+ return type(cfg_dict)(_lazy2string(v, dict_type) for v in cfg_dict)
42
+ elif isinstance(cfg_dict, (LazyAttr, LazyObject)):
43
+ return f'{cfg_dict.module}.{str(cfg_dict)}'
44
+ else:
45
+ return cfg_dict
46
+
47
+
48
+ class ConfigDict(Dict):
49
+ """A dictionary for config which has the same interface as python's built-
50
+ in dictionary and can be used as a normal dictionary.
51
+
52
+ The Config class would transform the nested fields (dictionary-like fields)
53
+ in config file into ``ConfigDict``.
54
+
55
+ If the class attribute ``lazy`` is ``False``, users will get the
56
+ object built by ``LazyObject`` or ``LazyAttr``, otherwise users will get
57
+ the ``LazyObject`` or ``LazyAttr`` itself.
58
+
59
+ The ``lazy`` should be set to ``True`` to avoid building the imported
60
+ object during configuration parsing, and it should be set to False outside
61
+ the Config to ensure that users do not experience the ``LazyObject``.
62
+ """
63
+ lazy = False
64
+
65
+ def __init__(__self, *args, **kwargs):
66
+ object.__setattr__(__self, '__parent', kwargs.pop('__parent', None))
67
+ object.__setattr__(__self, '__key', kwargs.pop('__key', None))
68
+ object.__setattr__(__self, '__frozen', False)
69
+ for arg in args:
70
+ if not arg:
71
+ continue
72
+ # Since ConfigDict.items will convert LazyObject to real object
73
+ # automatically, we need to call super().items() to make sure
74
+ # the LazyObject will not be converted.
75
+ if isinstance(arg, ConfigDict):
76
+ for key, val in dict.items(arg):
77
+ __self[key] = __self._hook(val)
78
+ elif isinstance(arg, dict):
79
+ for key, val in arg.items():
80
+ __self[key] = __self._hook(val)
81
+ elif isinstance(arg, tuple) and (not isinstance(arg[0], tuple)):
82
+ __self[arg[0]] = __self._hook(arg[1])
83
+ else:
84
+ for key, val in iter(arg):
85
+ __self[key] = __self._hook(val)
86
+
87
+ for key, val in dict.items(kwargs):
88
+ __self[key] = __self._hook(val)
89
+
90
+ def __missing__(self, name):
91
+ raise KeyError(name)
92
+
93
+ def __getattr__(self, name):
94
+ try:
95
+ value = super().__getattr__(name)
96
+ if isinstance(value, (LazyAttr, LazyObject)) and not self.lazy:
97
+ value = value.build()
98
+ except KeyError:
99
+ raise AttributeError(f"'{self.__class__.__name__}' object has no "
100
+ f"attribute '{name}'")
101
+ except Exception as e:
102
+ raise e
103
+ else:
104
+ return value
105
+
106
+ @classmethod
107
+ def _hook(cls, item):
108
+ # avoid to convert user defined dict to ConfigDict.
109
+ if type(item) in (dict, OrderedDict):
110
+ return cls(item)
111
+ elif isinstance(item, (list, tuple)):
112
+ return type(item)(cls._hook(elem) for elem in item)
113
+ return item
114
+
115
+ def __setattr__(self, name, value):
116
+ value = self._hook(value)
117
+ return super().__setattr__(name, value)
118
+
119
+ def __setitem__(self, name, value):
120
+ value = self._hook(value)
121
+ return super().__setitem__(name, value)
122
+
123
+ def __getitem__(self, key):
124
+ return self.build_lazy(super().__getitem__(key))
125
+
126
+ def __deepcopy__(self, memo):
127
+ other = self.__class__()
128
+ memo[id(self)] = other
129
+ for key, value in super().items():
130
+ other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo)
131
+ return other
132
+
133
+ def __copy__(self):
134
+ other = self.__class__()
135
+ for key, value in super().items():
136
+ other[key] = value
137
+ return other
138
+
139
+ copy = __copy__
140
+
141
+ def __iter__(self):
142
+ # Implement `__iter__` to overwrite the unpacking operator `**cfg_dict`
143
+ # to get the built lazy object
144
+ return iter(self.keys())
145
+
146
+ def get(self, key: str, default: Optional[Any] = None) -> Any:
147
+ """Get the value of the key. If class attribute ``lazy`` is True, the
148
+ LazyObject will be built and returned.
149
+
150
+ Args:
151
+ key (str): The key.
152
+ default (any, optional): The default value. Defaults to None.
153
+
154
+ Returns:
155
+ Any: The value of the key.
156
+ """
157
+ return self.build_lazy(super().get(key, default))
158
+
159
+ def pop(self, key, default=None):
160
+ """Pop the value of the key. If class attribute ``lazy`` is True, the
161
+ LazyObject will be built and returned.
162
+
163
+ Args:
164
+ key (str): The key.
165
+ default (any, optional): The default value. Defaults to None.
166
+
167
+ Returns:
168
+ Any: The value of the key.
169
+ """
170
+ return self.build_lazy(super().pop(key, default))
171
+
172
+ def update(self, *args, **kwargs) -> None:
173
+ """Override this method to make sure the LazyObject will not be built
174
+ during updating."""
175
+ other = {}
176
+ if args:
177
+ if len(args) > 1:
178
+ raise TypeError('update only accept one positional argument')
179
+ # Avoid to used self.items to build LazyObject
180
+ for key, value in dict.items(args[0]):
181
+ other[key] = value
182
+
183
+ for key, value in dict(kwargs).items():
184
+ other[key] = value
185
+ for k, v in other.items():
186
+ if ((k not in self) or (not isinstance(self[k], dict))
187
+ or (not isinstance(v, dict))):
188
+ self[k] = self._hook(v)
189
+ else:
190
+ self[k].update(v)
191
+
192
+ def build_lazy(self, value: Any) -> Any:
193
+ """If class attribute ``lazy`` is False, the LazyObject will be built
194
+ and returned.
195
+
196
+ Args:
197
+ value (Any): The value to be built.
198
+
199
+ Returns:
200
+ Any: The built value.
201
+ """
202
+ if isinstance(value, (LazyAttr, LazyObject)) and not self.lazy:
203
+ value = value.build()
204
+ return value
205
+
206
+ def values(self):
207
+ """Yield the values of the dictionary.
208
+
209
+ If class attribute ``lazy`` is False, the value of ``LazyObject`` or
210
+ ``LazyAttr`` will be built and returned.
211
+ """
212
+ values = []
213
+ for value in super().values():
214
+ values.append(self.build_lazy(value))
215
+ return values
216
+
217
+ def items(self):
218
+ """Yield the keys and values of the dictionary.
219
+
220
+ If class attribute ``lazy`` is False, the value of ``LazyObject`` or
221
+ ``LazyAttr`` will be built and returned.
222
+ """
223
+ items = []
224
+ for key, value in super().items():
225
+ items.append((key, self.build_lazy(value)))
226
+ return items
227
+
228
+ def merge(self, other: dict):
229
+ """Merge another dictionary into current dictionary.
230
+
231
+ Args:
232
+ other (dict): Another dictionary.
233
+ """
234
+ default = object()
235
+
236
+ def _merge_a_into_b(a, b):
237
+ if isinstance(a, dict):
238
+ if not isinstance(b, dict):
239
+ a.pop(DELETE_KEY, None)
240
+ return a
241
+ if a.pop(DELETE_KEY, False):
242
+ b.clear()
243
+ all_keys = list(b.keys()) + list(a.keys())
244
+ return {
245
+ key:
246
+ _merge_a_into_b(a.get(key, default), b.get(key, default))
247
+ for key in all_keys if key != DELETE_KEY
248
+ }
249
+ else:
250
+ return a if a is not default else b
251
+
252
+ merged = _merge_a_into_b(copy.deepcopy(other), copy.deepcopy(self))
253
+ self.clear()
254
+ for key, value in merged.items():
255
+ self[key] = value
256
+
257
+ def __reduce_ex__(self, proto):
258
+ # Override __reduce_ex__ to avoid `self.items` will be
259
+ # called by CPython interpreter during pickling. See more details in
260
+ # https://github.com/python/cpython/blob/8d61a71f9c81619e34d4a30b625922ebc83c561b/Objects/typeobject.c#L6196 # noqa: E501
261
+ from ...utils import digit_version
262
+ if digit_version(platform.python_version()) < digit_version('3.8'):
263
+ return (self.__class__, ({k: v
264
+ for k, v in super().items()}, ), None,
265
+ None, None)
266
+ else:
267
+ return (self.__class__, ({k: v
268
+ for k, v in super().items()}, ), None,
269
+ None, None, None)
270
+
271
+ def __eq__(self, other):
272
+ if isinstance(other, ConfigDict):
273
+ return other.to_dict() == self.to_dict()
274
+ elif isinstance(other, dict):
275
+ return {k: v for k, v in self.items()} == other
276
+ else:
277
+ return False
278
+
279
+ def _to_lazy_dict(self):
280
+ """Convert the ConfigDict to a normal dictionary recursively, and keep
281
+ the ``LazyObject`` or ``LazyAttr`` object not built."""
282
+
283
+ def _to_dict(data):
284
+ if isinstance(data, ConfigDict):
285
+ return {
286
+ key: _to_dict(value)
287
+ for key, value in Dict.items(data)
288
+ }
289
+ elif isinstance(data, dict):
290
+ return {key: _to_dict(value) for key, value in data.items()}
291
+ elif isinstance(data, (list, tuple)):
292
+ return type(data)(_to_dict(item) for item in data)
293
+ else:
294
+ return data
295
+
296
+ return _to_dict(self)
297
+
298
+ def to_dict(self):
299
+ """Convert the ConfigDict to a normal dictionary recursively, and
300
+ convert the ``LazyObject`` or ``LazyAttr`` to string."""
301
+ return _lazy2string(self, dict_type=dict)
302
+
303
+
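A minimal sketch of attribute access and recursive merging with the ConfigDict defined above (the keys are illustrative):

    cfg = ConfigDict(dict(model=dict(type='SegFormer', depth=3)))
    print(cfg.model.type)   # SegFormer
    cfg.merge(dict(model=dict(depth=5)))
    print(cfg.model.depth)  # 5, while model.type is kept
    print(cfg.to_dict())    # plain nested dict, lazy objects rendered as strings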
304
+ def add_args(parser: ArgumentParser,
305
+ cfg: dict,
306
+ prefix: str = '') -> ArgumentParser:
307
+ """Add config fields into argument parser.
308
+
309
+ Args:
310
+ parser (ArgumentParser): Argument parser.
311
+ cfg (dict): Config dictionary.
312
+ prefix (str, optional): Prefix of parser argument.
313
+ Defaults to ''.
314
+
315
+ Returns:
316
+ ArgumentParser: Argument parser containing config fields.
317
+ """
318
+ for k, v in cfg.items():
319
+ if isinstance(v, str):
320
+ parser.add_argument('--' + prefix + k)
321
+ elif isinstance(v, bool):
322
+ parser.add_argument('--' + prefix + k, action='store_true')
323
+ elif isinstance(v, int):
324
+ parser.add_argument('--' + prefix + k, type=int)
325
+ elif isinstance(v, float):
326
+ parser.add_argument('--' + prefix + k, type=float)
327
+ elif isinstance(v, dict):
328
+ add_args(parser, v, prefix + k + '.')
329
+ elif isinstance(v, abc.Iterable):
330
+ parser.add_argument(
331
+ '--' + prefix + k, type=type(next(iter(v))), nargs='+')
332
+ return parser
333
+
334
+
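A minimal sketch of add_args, turning flat config fields into command-line flags (the field names are made up for illustration):

    from argparse import ArgumentParser

    parser = add_args(ArgumentParser(), dict(lr=0.01, epochs=100, work_dir='runs'))
    args = parser.parse_args(['--lr', '0.05', '--epochs', '50'])
    print(args.lr, args.epochs, args.work_dir)  # 0.05 50 None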
335
+ class Config:
336
+ """A facility for config and config files.
337
+
338
+ It supports common file formats as configs: python/json/yaml.
339
+ ``Config.fromfile`` can parse a dictionary from a config file, then
340
+ build a ``Config`` instance with the dictionary.
341
+ The interface is the same as a dict object and also allows access config
342
+ values as attributes.
343
+
344
+ Args:
345
+ cfg_dict (dict, optional): A config dictionary. Defaults to None.
346
+ cfg_text (str, optional): Text of config. Defaults to None.
347
+ filename (str or Path, optional): Name of config file.
348
+ Defaults to None.
349
+ format_python_code (bool): Whether to format Python code by yapf.
350
+ Defaults to True.
351
+ """ # noqa: E501
352
+
353
+ def __init__(
354
+ self,
355
+ cfg_dict: Optional[dict] = None,
356
+ cfg_text: Optional[str] = None,
357
+ filename: Optional[Union[str, Path]] = None,
358
+ env_variables: Optional[dict] = None,
359
+ format_python_code: bool = True,
360
+ ):
361
+ filename = str(filename) if isinstance(filename, Path) else filename
362
+ if cfg_dict is None:
363
+ cfg_dict = dict()
364
+ elif not isinstance(cfg_dict, dict):
365
+ raise TypeError('cfg_dict must be a dict, but '
366
+ f'got {type(cfg_dict)}')
367
+ for key in cfg_dict:
368
+ if key in RESERVED_KEYS:
369
+ raise KeyError(f'{key} is reserved for config file')
370
+
371
+ if not isinstance(cfg_dict, ConfigDict):
372
+ cfg_dict = ConfigDict(cfg_dict)
373
+ super().__setattr__('_cfg_dict', cfg_dict)
374
+ super().__setattr__('_filename', filename)
375
+ super().__setattr__('_format_python_code', format_python_code)
376
+ if not hasattr(self, '_imported_names'):
377
+ super().__setattr__('_imported_names', set())
378
+
379
+ if cfg_text:
380
+ text = cfg_text
381
+ elif filename:
382
+ with open(filename, encoding='utf-8') as f:
383
+ text = f.read()
384
+ else:
385
+ text = ''
386
+ super().__setattr__('_text', text)
387
+ if env_variables is None:
388
+ env_variables = dict()
389
+ super().__setattr__('_env_variables', env_variables)
390
+
391
+ @staticmethod
392
+ def fromfile(filename: Union[str, Path],
393
+ use_predefined_variables: bool = True,
394
+ import_custom_modules: bool = True,
395
+ use_environment_variables: bool = True,
396
+ lazy_import: Optional[bool] = None,
397
+ format_python_code: bool = True) -> 'Config':
398
+ """Build a Config instance from config file.
399
+
400
+ Args:
401
+ filename (str or Path): Name of config file.
402
+ use_predefined_variables (bool, optional): Whether to use
403
+ predefined variables. Defaults to True.
404
+ import_custom_modules (bool, optional): Whether to support
405
+ importing custom modules in config. Defaults to True.
406
+ use_environment_variables (bool, optional): Whether to use
407
+ environment variables. Defaults to True.
408
+ lazy_import (bool): Whether to load config in `lazy_import` mode.
409
+ If it is `None`, it will be deduced by the content of the
410
+ config file. Defaults to None.
411
+ format_python_code (bool): Whether to format Python code by yapf.
412
+ Defaults to True.
413
+
414
+ Returns:
415
+ Config: Config instance built from config file.
416
+ """
417
+ filename = str(filename) if isinstance(filename, Path) else filename
418
+ if lazy_import is False or \
419
+ lazy_import is None and not Config._is_lazy_import(filename):
420
+ cfg_dict, cfg_text, env_variables = Config._file2dict(
421
+ filename, use_predefined_variables, use_environment_variables,
422
+ lazy_import)
423
+ if import_custom_modules and cfg_dict.get('custom_imports', None):
424
+ try:
425
+ import_modules_from_strings(**cfg_dict['custom_imports'])
426
+ except ImportError as e:
427
+ err_msg = (
428
+ 'Failed to import custom modules from '
429
+ f"{cfg_dict['custom_imports']}, the current sys.path "
430
+ 'is: ')
431
+ for p in sys.path:
432
+ err_msg += f'\n {p}'
433
+ err_msg += (
434
+ '\nYou should set `PYTHONPATH` to make `sys.path` '
435
+ 'include the directory which contains your custom '
436
+ 'module')
437
+ raise ImportError(err_msg) from e
438
+ return Config(
439
+ cfg_dict,
440
+ cfg_text=cfg_text,
441
+ filename=filename,
442
+ env_variables=env_variables,
443
+ )
444
+ else:
445
+ # Enable lazy import when parsing the config.
446
+ # Using try-except to make sure ``ConfigDict.lazy`` will be reset
447
+ # to False. See more details about lazy in the docstring of
448
+ # ConfigDict
449
+ ConfigDict.lazy = True
450
+ try:
451
+ cfg_dict, imported_names = Config._parse_lazy_import(filename)
452
+ except Exception as e:
453
+ raise e
454
+ finally:
455
+ # disable lazy import to get the real type. See more details
456
+ # about lazy in the docstring of ConfigDict
457
+ ConfigDict.lazy = False
458
+
459
+ cfg = Config(
460
+ cfg_dict,
461
+ filename=filename,
462
+ format_python_code=format_python_code)
463
+ object.__setattr__(cfg, '_imported_names', imported_names)
464
+ return cfg
465
+
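A minimal usage sketch of ``fromfile`` (the config paths below are hypothetical):

.. code-block:: python

    # Classic (non-lazy) config: `_base_`, predefined variables and
    # environment variables are resolved while parsing.
    cfg = Config.fromfile('configs/segformer_mit_b0.py')

    # Force lazy-import parsing; imported names are recorded in
    # `_imported_names` and skipped by `dump()`/`to_dict()` by default.
    lazy_cfg = Config.fromfile('configs/my_lazy_config.py', lazy_import=True)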
466
+ @staticmethod
467
+ def _get_base_modules(nodes: list) -> list:
468
+ """Get base module name from parsed code.
469
+
470
+ Args:
471
+ nodes (list): Parsed code of the config file.
472
+
473
+ Returns:
474
+ list: Names of base modules.
475
+ """
476
+
477
+ def _get_base_module_from_with(with_nodes: list) -> list:
478
+ """Get base module names from the ``with`` statement in a python file.
479
+
480
+ Args:
481
+ with_nodes (list): Nodes in the body of the ``with`` statement.
482
+
483
+ Returns:
484
+ list: Names of base modules.
485
+ """
486
+ base_modules = []
487
+ for node in with_nodes:
488
+ assert isinstance(node, ast.ImportFrom), (
489
+ 'Illegal syntax in config file! Only '
490
+ '`from ... import ...` could be implemented in '
491
+ '`with read_base()`')
492
+ assert node.module is not None, (
493
+ 'Illegal syntax in config file! Syntax like '
494
+ '`from . import xxx` is not allowed in `with read_base()`')
495
+ base_modules.append(node.level * '.' + node.module)
496
+ return base_modules
497
+
498
+ for idx, node in enumerate(nodes):
499
+ if (isinstance(node, ast.Assign)
500
+ and isinstance(node.targets[0], ast.Name)
501
+ and node.targets[0].id == BASE_KEY):
502
+ raise SyntaxError(
503
+ 'The configuration file type in the inheritance chain '
504
+ 'must match the current configuration file type, either '
505
+ '"lazy_import" or non-"lazy_import". You got this error '
506
+ f'since you use the syntax like `_base_ = "{node.targets[0].id}"` ' # noqa: E501
507
+ 'in your config. You should use `with read_base(): ...` to ' # noqa: E501
508
+ 'mark the inherited config file.' # noqa: E501
509
+ )
510
+
511
+ if not isinstance(node, ast.With):
512
+ continue
513
+
514
+ expr = node.items[0].context_expr
515
+ if (not isinstance(expr, ast.Call)
516
+ or not expr.func.id == 'read_base' or # type: ignore
517
+ len(node.items) > 1):
518
+ raise SyntaxError(
519
+ 'Only `read_base` context manager can be used in the '
520
+ 'config')
521
+ for nested_idx, nested_node in enumerate(node.body):
522
+ nodes.insert(idx + nested_idx + 1, nested_node)
523
+ nodes.pop(idx)
524
+ return _get_base_module_from_with(node.body)
525
+ return []
526
+
527
+ @staticmethod
528
+ def _validate_py_syntax(filename: str):
529
+ """Validate syntax of python config.
530
+
531
+ Args:
532
+ filename (str): Filename of python config file.
533
+ """
534
+ with open(filename, encoding='utf-8') as f:
535
+ content = f.read()
536
+ try:
537
+ ast.parse(content)
538
+ except SyntaxError as e:
539
+ raise SyntaxError('There are syntax errors in config '
540
+ f'file {filename}: {e}')
541
+
542
+ @staticmethod
543
+ def _substitute_predefined_vars(filename: str, temp_config_name: str):
544
+ """Substitute predefined variables in config with actual values.
545
+
546
+ Sometimes we want some variables in the config to be related to the
547
+ current path or file name, etc.
548
+
549
+ Here is an example of a typical usage scenario. When training a model,
550
+ we define a working directory in the config that saves the models and
551
+ logs. For different configs, we expect to define different working
552
+ directories. A common way for users is to use the config file name
553
+ directly as part of the working directory name, e.g. for the config
554
+ ``config_setting1.py``, the working directory is
555
+ ``./work_dir/config_setting1``.
556
+
557
+ This can be easily achieved using predefined variables, which can be
558
+ written in the config `config_setting1.py` as follows
559
+
560
+ .. code-block:: python
561
+
562
+ work_dir = './work_dir/{{ fileBasenameNoExtension }}'
563
+
564
+
565
+ Here `{{ fileBasenameNoExtension }}` indicates the file name of the
566
+ config (without the extension), and when the config class reads the
567
+ config file, it will automatically parse this double-bracketed string
568
+ to the corresponding actual value.
569
+
570
+ .. code-block:: python
571
+
572
+ cfg = Config.fromfile('./config_setting1.py')
573
+ cfg.work_dir # "./work_dir/config_setting1"
574
+
575
+
576
+ For details, please refer to docs/zh_cn/advanced_tutorials/config.md.
577
+
578
+ Args:
579
+ filename (str): Filename of config.
580
+ temp_config_name (str): Temporary filename to save substituted
581
+ config.
582
+ """
583
+ file_dirname = osp.dirname(filename)
584
+ file_basename = osp.basename(filename)
585
+ file_basename_no_extension = osp.splitext(file_basename)[0]
586
+ file_extname = osp.splitext(filename)[1]
587
+ support_templates = dict(
588
+ fileDirname=file_dirname,
589
+ fileBasename=file_basename,
590
+ fileBasenameNoExtension=file_basename_no_extension,
591
+ fileExtname=file_extname)
592
+ with open(filename, encoding='utf-8') as f:
593
+ config_file = f.read()
594
+ for key, value in support_templates.items():
595
+ regexp = r'\{\{\s*' + str(key) + r'\s*\}\}'
596
+ value = value.replace('\\', '/')
597
+ config_file = re.sub(regexp, value, config_file)
598
+ with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file:
599
+ tmp_config_file.write(config_file)
600
+
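A small sketch of the substitution this helper performs; the private method is called directly only for illustration and the file names are made up:

.. code-block:: python

    import os.path as osp
    import tempfile

    tmp_dir = tempfile.mkdtemp()
    cfg_path = osp.join(tmp_dir, 'config_setting1.py')
    with open(cfg_path, 'w', encoding='utf-8') as f:
        f.write("work_dir = './work_dir/{{ fileBasenameNoExtension }}'\n")
    out_path = osp.join(tmp_dir, 'substituted.py')
    Config._substitute_predefined_vars(cfg_path, out_path)
    # substituted.py now contains: work_dir = './work_dir/config_setting1'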
601
+ @staticmethod
602
+ def _substitute_env_variables(filename: str, temp_config_name: str):
603
+ """Substitute environment variables in config with actual values.
604
+
605
+ Sometimes, we want to change some items in the config with environment
606
+ variables. For example, we expect to change the dataset root by setting
607
+ ``DATASET_ROOT=/dataset/root/path`` in the command line. This can be
608
+ easily achieved by writing lines in the config as follows
609
+
610
+ .. code-block:: python
611
+
612
+ data_root = '{{$DATASET_ROOT:/default/dataset}}/images'
613
+
614
+
615
+ Here, ``{{$DATASET_ROOT:/default/dataset}}`` indicates using the
616
+ environment variable ``DATASET_ROOT`` to replace the part between
617
+ ``{{}}``. If the ``DATASET_ROOT`` is not set, the default value
618
+ ``/default/dataset`` will be used.
619
+
620
+ Environment variables can not only replace items in strings; they
621
+ can also substitute other types of data in the config. In this situation,
622
+ we can write the config as below
623
+
624
+ .. code-block:: python
625
+
626
+ model = dict(
627
+ bbox_head = dict(num_classes={{'$NUM_CLASSES:80'}}))
628
+
629
+
630
+ For details, please refer to docs/zh_cn/tutorials/config.md.
631
+
632
+ Args:
633
+ filename (str): Filename of config.
634
+ temp_config_name (str): Temporary filename to save substituted
635
+ config.
636
+ """
637
+ with open(filename, encoding='utf-8') as f:
638
+ config_file = f.read()
639
+ regexp = r'\{\{[\'\"]?\s*\$(\w+)\s*\:\s*(\S*?)\s*[\'\"]?\}\}'
640
+ keys = re.findall(regexp, config_file)
641
+ env_variables = dict()
642
+ for var_name, value in keys:
643
+ regexp = r'\{\{[\'\"]?\s*\$' + var_name + r'\s*\:\s*' \
644
+ + value + r'\s*[\'\"]?\}\}'
645
+ if var_name in os.environ:
646
+ value = os.environ[var_name]
647
+ env_variables[var_name] = value
648
+ if not value:
649
+ raise KeyError(f'`{var_name}` cannot be found in `os.environ`.'
650
+ f' Please set `{var_name}` in environment or '
651
+ 'give a default value.')
652
+ config_file = re.sub(regexp, value, config_file)
653
+
654
+ with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file:
655
+ tmp_config_file.write(config_file)
656
+ return env_variables
657
+
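A sketch of how the ``{{$VAR:default}}`` placeholders behave; the variable name and paths are invented:

.. code-block:: python

    import os
    import tempfile

    src = "data_root = '{{$DATASET_ROOT:/default/dataset}}/images'\n"
    with tempfile.NamedTemporaryFile('w', suffix='.py', delete=False,
                                     encoding='utf-8') as f:
        f.write(src)
        cfg_path = f.name

    os.environ['DATASET_ROOT'] = '/my/dataset'
    env = Config._substitute_env_variables(cfg_path, cfg_path)
    # env == {'DATASET_ROOT': '/my/dataset'}; the file now reads
    # data_root = '/my/dataset/images'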
658
+ @staticmethod
659
+ def _pre_substitute_base_vars(filename: str,
660
+ temp_config_name: str) -> dict:
661
+ """Preceding step for substituting variables in base config with actual
662
+ value.
663
+
664
+ Args:
665
+ filename (str): Filename of config.
666
+ temp_config_name (str): Temporary filename to save substituted
667
+ config.
668
+
669
+ Returns:
670
+ dict: A dictionary contains variables in base config.
671
+ """
672
+ with open(filename, encoding='utf-8') as f:
673
+ config_file = f.read()
674
+ base_var_dict = {}
675
+ regexp = r'\{\{\s*' + BASE_KEY + r'\.([\w\.]+)\s*\}\}'
676
+ base_vars = set(re.findall(regexp, config_file))
677
+ for base_var in base_vars:
678
+ randstr = f'_{base_var}_{uuid.uuid4().hex.lower()[:6]}'
679
+ base_var_dict[randstr] = base_var
680
+ regexp = r'\{\{\s*' + BASE_KEY + r'\.' + base_var + r'\s*\}\}'
681
+ config_file = re.sub(regexp, f'"{randstr}"', config_file)
682
+ with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file:
683
+ tmp_config_file.write(config_file)
684
+ return base_var_dict
685
+
686
+ @staticmethod
687
+ def _substitute_base_vars(cfg: Any, base_var_dict: dict,
688
+ base_cfg: dict) -> Any:
689
+ """Substitute base variables from strings to their actual values.
690
+
691
+ Args:
692
+ cfg (Any): Config dictionary.
693
+ base_var_dict (dict): A dictionary contains variables in base
694
+ config.
695
+ base_cfg (dict): Base config dictionary.
696
+
697
+ Returns:
698
+ Any: A config with the original base variables
699
+ substituted with actual values.
700
+ """
701
+ cfg = copy.deepcopy(cfg)
702
+
703
+ if isinstance(cfg, dict):
704
+ for k, v in cfg.items():
705
+ if isinstance(v, str) and v in base_var_dict:
706
+ new_v = base_cfg
707
+ for new_k in base_var_dict[v].split('.'):
708
+ new_v = new_v[new_k]
709
+ cfg[k] = new_v
710
+ elif isinstance(v, (list, tuple, dict)):
711
+ cfg[k] = Config._substitute_base_vars(
712
+ v, base_var_dict, base_cfg)
713
+ elif isinstance(cfg, tuple):
714
+ cfg = tuple(
715
+ Config._substitute_base_vars(c, base_var_dict, base_cfg)
716
+ for c in cfg)
717
+ elif isinstance(cfg, list):
718
+ cfg = [
719
+ Config._substitute_base_vars(c, base_var_dict, base_cfg)
720
+ for c in cfg
721
+ ]
722
+ elif isinstance(cfg, str) and cfg in base_var_dict:
723
+ new_v = base_cfg
724
+ for new_k in base_var_dict[cfg].split('.'):
725
+ new_v = new_v[new_k]
726
+ cfg = new_v
727
+
728
+ return cfg
729
+
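Taken on its own, the second step works roughly like this (the placeholder string and keys are invented):

.. code-block:: python

    base_cfg = {'model': {'depth': 50}}
    # produced by _pre_substitute_base_vars for '{{_base_.model.depth}}'
    base_var_dict = {'_model.depth_3fa2b1': 'model.depth'}
    cfg = {'backbone': {'depth': '_model.depth_3fa2b1'}}

    resolved = Config._substitute_base_vars(cfg, base_var_dict, base_cfg)
    # resolved == {'backbone': {'depth': 50}}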
730
+ @staticmethod
731
+ def _file2dict(
732
+ filename: str,
733
+ use_predefined_variables: bool = True,
734
+ use_environment_variables: bool = True,
735
+ lazy_import: Optional[bool] = None) -> Tuple[dict, str, dict]:
736
+ """Transform file to variables dictionary.
737
+
738
+ Args:
739
+ filename (str): Name of config file.
740
+ use_predefined_variables (bool, optional): Whether to use
741
+ predefined variables. Defaults to True.
742
+ use_environment_variables (bool, optional): Whether to use
743
+ environment variables. Defaults to True.
744
+ lazy_import (bool): Whether to load config in `lazy_import` mode.
745
+ If it is `None`, it will be deduced by the content of the
746
+ config file. Defaults to None.
747
+
748
+ Returns:
749
+ Tuple[dict, str, dict]: Variables dictionary, config text, and used environment variables.
750
+ """
751
+ if lazy_import is None and Config._is_lazy_import(filename):
752
+ raise RuntimeError(
753
+ 'The configuration file type in the inheritance chain '
754
+ 'must match the current configuration file type, either '
755
+ '"lazy_import" or non-"lazy_import". You got this error '
756
+ 'since you use the syntax like `with read_base(): ...` '
757
+ f'or import non-builtin module in {filename}.' # noqa: E501
758
+ )
759
+
760
+ filename = osp.abspath(osp.expanduser(filename))
761
+ check_file_exist(filename)
762
+ fileExtname = osp.splitext(filename)[1]
763
+ if fileExtname not in ['.py', '.json', '.yaml', '.yml']:
764
+ raise OSError('Only py/yml/yaml/json type are supported now!')
765
+ try:
766
+ with tempfile.TemporaryDirectory() as temp_config_dir:
767
+ temp_config_file = tempfile.NamedTemporaryFile(
768
+ dir=temp_config_dir, suffix=fileExtname, delete=False)
769
+ if platform.system() == 'Windows':
770
+ temp_config_file.close()
771
+
772
+ # Substitute predefined variables
773
+ if use_predefined_variables:
774
+ Config._substitute_predefined_vars(filename,
775
+ temp_config_file.name)
776
+ else:
777
+ shutil.copyfile(filename, temp_config_file.name)
778
+ # Substitute environment variables
779
+ env_variables = dict()
780
+ if use_environment_variables:
781
+ env_variables = Config._substitute_env_variables(
782
+ temp_config_file.name, temp_config_file.name)
783
+ # Substitute base variables from placeholders to strings
784
+ base_var_dict = Config._pre_substitute_base_vars(
785
+ temp_config_file.name, temp_config_file.name)
786
+
787
+ # Handle base files
788
+ base_cfg_dict = ConfigDict()
789
+ cfg_text_list = list()
790
+ for base_cfg_path in Config._get_base_files(
791
+ temp_config_file.name):
792
+ base_cfg_path, scope = Config._get_cfg_path(
793
+ base_cfg_path, filename)
794
+ _cfg_dict, _cfg_text, _env_variables = Config._file2dict(
795
+ filename=base_cfg_path,
796
+ use_predefined_variables=use_predefined_variables,
797
+ use_environment_variables=use_environment_variables,
798
+ lazy_import=lazy_import,
799
+ )
800
+ cfg_text_list.append(_cfg_text)
801
+ env_variables.update(_env_variables)
802
+ duplicate_keys = base_cfg_dict.keys() & _cfg_dict.keys()
803
+ if len(duplicate_keys) > 0:
804
+ raise KeyError(
805
+ 'Duplicate key is not allowed among bases. '
806
+ f'Duplicate keys: {duplicate_keys}')
807
+
808
+ # _dict_to_config_dict will do the following things:
809
+ # 1. Recursively converts ``dict`` to :obj:`ConfigDict`.
810
+ # 2. Set `_scope_` for the outer dict variable for the base
811
+ # config.
812
+ # 3. Set `scope` attribute for each base variable.
813
+ # Different from `_scope_`, `scope` is not a key of base
814
+ # dict, `scope` attribute will be parsed to key `_scope_`
815
+ # by function `_parse_scope` only if the base variable is
816
+ # accessed by the current config.
817
+ _cfg_dict = Config._dict_to_config_dict(_cfg_dict, scope)
818
+ base_cfg_dict.update(_cfg_dict)
819
+
820
+ if filename.endswith('.py'):
821
+ with open(temp_config_file.name, encoding='utf-8') as f:
822
+ parsed_codes = ast.parse(f.read())
823
+ parsed_codes = RemoveAssignFromAST(BASE_KEY).visit(
824
+ parsed_codes)
825
+ codeobj = compile(parsed_codes, filename, mode='exec')
826
+ # Support load global variable in nested function of the
827
+ # config.
828
+ global_locals_var = {BASE_KEY: base_cfg_dict}
829
+ ori_keys = set(global_locals_var.keys())
830
+ eval(codeobj, global_locals_var, global_locals_var)
831
+ cfg_dict = {
832
+ key: value
833
+ for key, value in global_locals_var.items()
834
+ if (key not in ori_keys and not key.startswith('__'))
835
+ }
836
+ elif filename.endswith(('.yml', '.yaml', '.json')):
837
+ cfg = OmegaConf.load(temp_config_file.name)
838
+ cfg_dict = OmegaConf.to_container(cfg, resolve=True)
839
+ # close temp file
840
+ for key, value in list(cfg_dict.items()):
841
+ if isinstance(value,
842
+ (types.FunctionType, types.ModuleType)):
843
+ cfg_dict.pop(key)
844
+ temp_config_file.close()
845
+
846
+ # If the current config accesses a base variable of base
847
+ # configs, The ``scope`` attribute of corresponding variable
848
+ # will be converted to the `_scope_`.
849
+ Config._parse_scope(cfg_dict)
850
+ except Exception as e:
851
+ if osp.exists(temp_config_dir):
852
+ shutil.rmtree(temp_config_dir)
853
+ raise e
854
+
855
+ # check deprecation information
856
+ if DEPRECATION_KEY in cfg_dict:
857
+ deprecation_info = cfg_dict.pop(DEPRECATION_KEY)
858
+ warning_msg = f'The config file {filename} will be deprecated ' \
859
+ 'in the future.'
860
+ if 'expected' in deprecation_info:
861
+ warning_msg += f' Please use {deprecation_info["expected"]} ' \
862
+ 'instead.'
863
+ if 'reference' in deprecation_info:
864
+ warning_msg += ' More information can be found at ' \
865
+ f'{deprecation_info["reference"]}'
866
+ warnings.warn(warning_msg, DeprecationWarning)
867
+
868
+ cfg_text = filename + '\n'
869
+ with open(filename, encoding='utf-8') as f:
870
+ # Setting encoding explicitly to resolve coding issue on windows
871
+ cfg_text += f.read()
872
+
873
+ # Substitute base variables from strings to their actual values
874
+ cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict,
875
+ base_cfg_dict)
876
+ cfg_dict.pop(BASE_KEY, None)
877
+
878
+ cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
879
+ cfg_dict = {
880
+ k: v
881
+ for k, v in cfg_dict.items() if not k.startswith('__')
882
+ }
883
+
884
+ # merge cfg_text
885
+ cfg_text_list.append(cfg_text)
886
+ cfg_text = '\n'.join(cfg_text_list)
887
+
888
+ return cfg_dict, cfg_text, env_variables
889
+
890
+ @staticmethod
891
+ def _parse_lazy_import(filename: str) -> Tuple[ConfigDict, set]:
892
+ """Transform file to variables dictionary.
893
+
894
+ Args:
895
+ filename (str): Name of config file.
896
+
897
+ Returns:
898
+ Tuple[ConfigDict, set]: ``cfg_dict`` and ``imported_names``.
899
+
900
+ - cfg_dict (dict): Variables dictionary of parsed config.
901
+ - imported_names (set): Used to mark the names of
902
+ imported objects.
903
+ """
904
+ # In lazy import mode, users can use the Python syntax `import` to
905
+ # implement inheritance between configuration files, which makes it easier
906
+ # for users to understand the hierarchical relationships between
907
+ # different configuration files.
908
+
909
+ # Besides, users can also use the `import` syntax to import the corresponding
910
+ # module which will be filled in the `type` field. It means users
911
+ # can directly navigate to the source of the module in the
912
+ # configuration file by clicking the `type` field.
913
+
914
+ # To avoid really importing the third party package like `torch`
915
+ # when importing the `type` objects, we use `_parse_lazy_import` to parse the
916
+ # configuration file, which will not actually trigger the import
917
+ # process, but simply parse the imported `type`s as LazyObject objects.
918
+
919
+ # The overall pipeline of _parse_lazy_import is:
920
+ # 1. Parse the base module from the config file.
921
+ # ||
922
+ # \/
923
+ # base_module = ['mmdet.configs.default_runtime']
924
+ # ||
925
+ # \/
926
+ # 2. recursively parse the base module and gather imported objects to
927
+ # a dict.
928
+ # ||
929
+ # \/
930
+ # The base_dict will be:
931
+ # {
932
+ # 'mmdet.configs.default_runtime': {...}
933
+ # 'mmdet.configs.retinanet_r50_fpn_1x_coco': {...}
934
+ # ...
935
+ # }, each item in base_dict is a dict of `LazyObject`
936
+ # 3. parse the current config file filling the imported variable
937
+ # with the base_dict.
938
+ #
939
+ # 4. During the parsing process, all imported variable will be
940
+ # recorded in the `imported_names` set. These variables can be
941
+ # accessed, but will not be dumped by default.
942
+
943
+ with open(filename, encoding='utf-8') as f:
944
+ global_dict = {'LazyObject': LazyObject, '__file__': filename}
945
+ base_dict = {}
946
+
947
+ parsed_codes = ast.parse(f.read())
948
+ # get the names of base modules, and remove the
949
+ # `with read_base():'` statement
950
+ base_modules = Config._get_base_modules(parsed_codes.body)
951
+ base_imported_names = set()
952
+ for base_module in base_modules:
953
+ # If base_module means a relative import, assuming the level is
954
+ # 2, which means the module is imported like
955
+ # `from ..a.b import c`, we must ensure that c is an
956
+ # object `defined` in module b, and module b should not be a
957
+ # package including `__init__` file but a single python file.
958
+ level = len(re.match(r'\.*', base_module).group())
959
+ if level > 0:
960
+ # Relative import
961
+ base_dir = osp.dirname(filename)
962
+ module_path = osp.join(
963
+ base_dir, *(['..'] * (level - 1)),
964
+ f'{base_module[level:].replace(".", "/")}.py')
965
+ else:
966
+ # Absolute import
967
+ module_list = base_module.split('.')
968
+ if len(module_list) == 1:
969
+ raise SyntaxError(
970
+ 'The imported configuration file should not be '
971
+ f'an independent package {module_list[0]}. Here '
972
+ 'is an example: '
973
+ '`with read_base(): from mmdet.configs.retinanet_r50_fpn_1x_coco import *`' # noqa: E501
974
+ )
975
+ else:
976
+ package = module_list[0]
977
+ root_path = get_installed_path(package)
978
+ module_path = f'{osp.join(root_path, *module_list[1:])}.py' # noqa: E501
979
+ if not osp.isfile(module_path):
980
+ raise SyntaxError(
981
+ f'{module_path} not found! It means that incorrect '
982
+ 'module is defined in '
983
+ f'`with read_base(): from {base_module} import ...`, please ' # noqa: E501
984
+ 'make sure the base config module is valid '
985
+ 'and is consistent with the prior import '
986
+ 'logic')
987
+ _base_cfg_dict, _base_imported_names = Config._parse_lazy_import( # noqa: E501
988
+ module_path)
989
+ base_imported_names |= _base_imported_names
990
+ # The base_dict will be:
991
+ # {
992
+ # 'mmdet.configs.default_runtime': {...}
993
+ # 'mmdet.configs.retinanet_r50_fpn_1x_coco': {...}
994
+ # ...
995
+ # }
996
+ base_dict[base_module] = _base_cfg_dict
997
+
998
+ # `base_dict` contains all the imported modules from `base_cfg`.
999
+ # In order to collect the specific imported module from `base_cfg`
1000
+ # before parsing the current file, we use an AST transform to
1001
+ # traverse the imported modules from base_cfg and merge them into
1002
+ # the global dict. After the AST transformation, most of the import
1003
+ # syntax will be removed (except for the builtin import) and
1004
+ # replaced with the `LazyObject`
1005
+ transform = ImportTransformer(
1006
+ global_dict=global_dict,
1007
+ base_dict=base_dict,
1008
+ filename=filename)
1009
+ modified_code = transform.visit(parsed_codes)
1010
+ modified_code, abs_imported = _gather_abs_import_lazyobj(
1011
+ modified_code, filename=filename)
1012
+ imported_names = transform.imported_obj | abs_imported
1013
+ imported_names |= base_imported_names
1014
+ modified_code = ast.fix_missing_locations(modified_code)
1015
+ exec(
1016
+ compile(modified_code, filename, mode='exec'), global_dict,
1017
+ global_dict)
1018
+
1019
+ ret: dict = {}
1020
+ for key, value in global_dict.items():
1021
+ if key.startswith('__') or key in ['LazyObject']:
1022
+ continue
1023
+ ret[key] = value
1024
+ # convert dict to ConfigDict
1025
+ cfg_dict = Config._dict_to_config_dict_lazy(ret)
1026
+
1027
+ return cfg_dict, imported_names
1028
+
1029
+ @staticmethod
1030
+ def _dict_to_config_dict_lazy(cfg: dict):
1031
+ """Recursively converts ``dict`` to :obj:`ConfigDict`. The only
1032
+ difference between ``_dict_to_config_dict_lazy`` and
1033
+ ``_dict_to_config_dict`` is that the former one does not consider
1034
+ the scope, and will not trigger the building of ``LazyObject``.
1035
+
1036
+ Args:
1037
+ cfg (dict): Config dict.
1038
+
1039
+ Returns:
1040
+ ConfigDict: Converted dict.
1041
+ """
1042
+ # Only the outer dict with key `type` should have the key `_scope_`.
1043
+ if isinstance(cfg, dict):
1044
+ cfg_dict = ConfigDict()
1045
+ for key, value in cfg.items():
1046
+ cfg_dict[key] = Config._dict_to_config_dict_lazy(value)
1047
+ return cfg_dict
1048
+ if isinstance(cfg, (tuple, list)):
1049
+ return type(cfg)(
1050
+ Config._dict_to_config_dict_lazy(_cfg) for _cfg in cfg)
1051
+ return cfg
1052
+
1053
+ @staticmethod
1054
+ def _dict_to_config_dict(cfg: dict,
1055
+ scope: Optional[str] = None,
1056
+ has_scope=True):
1057
+ """Recursively converts ``dict`` to :obj:`ConfigDict`.
1058
+
1059
+ Args:
1060
+ cfg (dict): Config dict.
1061
+ scope (str, optional): Scope of instance.
1062
+ has_scope (bool): Whether to add `_scope_` key to config dict.
1063
+
1064
+ Returns:
1065
+ ConfigDict: Converted dict.
1066
+ """
1067
+ # Only the outer dict with key `type` should have the key `_scope_`.
1068
+ if isinstance(cfg, dict):
1069
+ if has_scope and 'type' in cfg:
1070
+ has_scope = False
1071
+ if scope is not None and cfg.get('_scope_', None) is None:
1072
+ cfg._scope_ = scope # type: ignore
1073
+ cfg = ConfigDict(cfg)
1074
+ dict.__setattr__(cfg, 'scope', scope)
1075
+ for key, value in cfg.items():
1076
+ cfg[key] = Config._dict_to_config_dict(
1077
+ value, scope=scope, has_scope=has_scope)
1078
+ elif isinstance(cfg, tuple):
1079
+ cfg = tuple(
1080
+ Config._dict_to_config_dict(_cfg, scope, has_scope=has_scope)
1081
+ for _cfg in cfg)
1082
+ elif isinstance(cfg, list):
1083
+ cfg = [
1084
+ Config._dict_to_config_dict(_cfg, scope, has_scope=has_scope)
1085
+ for _cfg in cfg
1086
+ ]
1087
+ return cfg
1088
+
1089
+ @staticmethod
1090
+ def _parse_scope(cfg: dict) -> None:
1091
+ """Adds ``_scope_`` to a :obj:`ConfigDict` instance that marks a base
1092
+ variable.
1093
+
1094
+ If the config dict already has the scope, scope will not be
1095
+ overwritten.
1096
+
1097
+ Args:
1098
+ cfg (dict): Config needs to be parsed with scope.
1099
+ """
1100
+ if isinstance(cfg, ConfigDict):
1101
+ cfg._scope_ = cfg.scope
1102
+ elif isinstance(cfg, (tuple, list)):
1103
+ [Config._parse_scope(value) for value in cfg]
1104
+ else:
1105
+ return
1106
+
1107
+ @staticmethod
1108
+ def _get_base_files(filename: str) -> list:
1109
+ """Get the base config file.
1110
+
1111
+ Args:
1112
+ filename (str): The config file.
1113
+
1114
+ Raises:
1115
+ SyntaxError: If the config file type is not py/json/yaml/yml.
1116
+
1117
+ Returns:
1118
+ list: A list of base config paths.
1119
+ """
1120
+ file_format = osp.splitext(filename)[1]
1121
+ if file_format == '.py':
1122
+ Config._validate_py_syntax(filename)
1123
+ with open(filename, encoding='utf-8') as f:
1124
+ parsed_codes = ast.parse(f.read()).body
1125
+
1126
+ def is_base_line(c):
1127
+ return (isinstance(c, ast.Assign)
1128
+ and isinstance(c.targets[0], ast.Name)
1129
+ and c.targets[0].id == BASE_KEY)
1130
+
1131
+ base_code = next((c for c in parsed_codes if is_base_line(c)),
1132
+ None)
1133
+ if base_code is not None:
1134
+ base_code = ast.Expression( # type: ignore
1135
+ body=base_code.value) # type: ignore
1136
+ base_files = eval(compile(base_code, '',
1137
+ mode='eval')) # type: ignore
1138
+ else:
1139
+ base_files = []
1140
+ elif file_format in ('.yml', '.yaml', '.json'):
1141
+ cfg = OmegaConf.load(filename)
1142
+ cfg_dict = OmegaConf.to_container(cfg, resolve=True)
1143
+ base_files = cfg_dict.get(BASE_KEY, [])
1144
+ else:
1145
+ raise SyntaxError(
1146
+ 'The config type should be py, json, yaml or '
1147
+ f'yml, but got {file_format}')
1148
+ base_files = base_files if isinstance(base_files,
1149
+ list) else [base_files]
1150
+ return base_files
1151
+
1152
+ @staticmethod
1153
+ def _get_cfg_path(cfg_path: str,
1154
+ filename: str) -> Tuple[str, Optional[str]]:
1155
+ """Get the config path from the current or external package.
1156
+
1157
+ Args:
1158
+ cfg_path (str): Relative path of config.
1159
+ filename (str): The config file being parsed.
1160
+
1161
+ Returns:
1162
+ Tuple[str, str or None]: Path and scope of config. If the config
1163
+ is not an external config, the scope will be `None`.
1164
+ """
1165
+ if '::' in cfg_path:
1166
+ # A `cfg_path` containing '::' means an external config path.
1167
+ # Get package name and relative config path.
1168
+ scope = cfg_path.partition('::')[0]
1169
+ package, cfg_path = _get_package_and_cfg_path(cfg_path)
1170
+
1171
+ if not is_installed(package):
1172
+ raise ModuleNotFoundError(
1173
+ f'{package} is not installed, please install {package} '
1174
+ f'manually')
1175
+
1176
+ # Get installed package path.
1177
+ package_path = get_installed_path(package)
1178
+ try:
1179
+ # Get config path from meta file.
1180
+ cfg_path = _get_external_cfg_path(package_path, cfg_path)
1181
+ except ValueError:
1182
+ # Since base config does not have a metafile, it should be
1183
+ # concatenated with package path and relative config path.
1184
+ cfg_path = _get_external_cfg_base_path(package_path, cfg_path)
1185
+ except FileNotFoundError as e:
1186
+ raise e
1187
+ return cfg_path, scope
1188
+ else:
1189
+ # Get local config path.
1190
+ cfg_dir = osp.dirname(filename)
1191
+ cfg_path = osp.join(cfg_dir, cfg_path)
1192
+ return cfg_path, None
1193
+
1194
+ @staticmethod
1195
+ def _merge_a_into_b(a: dict,
1196
+ b: dict,
1197
+ allow_list_keys: bool = False) -> dict:
1198
+ """Merge dict ``a`` into dict ``b`` (non-inplace).
1199
+
1200
+ Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid
1201
+ in-place modifications.
1202
+
1203
+ Args:
1204
+ a (dict): The source dict to be merged into ``b``.
1205
+ b (dict): The origin dict into which keys from ``a`` are merged.
1206
+ allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
1207
+ are allowed in source ``a`` and will replace the element of the
1208
+ corresponding index in b if b is a list. Defaults to False.
1209
+
1210
+ Returns:
1211
+ dict: The modified dict of ``b`` using ``a``.
1212
+
1213
+ Examples:
1214
+ # Normally merge a into b.
1215
+ >>> Config._merge_a_into_b(
1216
+ ... dict(obj=dict(a=2)), dict(obj=dict(a=1)))
1217
+ {'obj': {'a': 2}}
1218
+
1219
+ # Delete b first and merge a into b.
1220
+ >>> Config._merge_a_into_b(
1221
+ ... dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1)))
1222
+ {'obj': {'a': 2}}
1223
+
1224
+ # b is a list
1225
+ >>> Config._merge_a_into_b(
1226
+ ... {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True)
1227
+ [{'a': 2}, {'b': 2}]
1228
+ """
1229
+ b = b.copy()
1230
+ for k, v in a.items():
1231
+ if allow_list_keys and k.isdigit() and isinstance(b, list):
1232
+ k = int(k)
1233
+ if len(b) <= k:
1234
+ raise KeyError(f'Index {k} exceeds the length of list {b}')
1235
+ b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
1236
+ elif isinstance(v, dict):
1237
+ if k in b and not v.pop(DELETE_KEY, False):
1238
+ allowed_types: Union[Tuple, type] = (
1239
+ dict, list) if allow_list_keys else dict
1240
+ if not isinstance(b[k], allowed_types):
1241
+ raise TypeError(
1242
+ f'{k}={v} in child config cannot inherit from '
1243
+ f'base because {k} is a dict in the child config '
1244
+ f'but is of type {type(b[k])} in base config. '
1245
+ f'You may set `{DELETE_KEY}=True` to ignore the '
1246
+ f'base config.')
1247
+ b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
1248
+ else:
1249
+ b[k] = ConfigDict(v)
1250
+ else:
1251
+ b[k] = v
1252
+ return b
1253
+
1254
+ @property
1255
+ def filename(self) -> str:
1256
+ """Get file name of config."""
1257
+ return self._filename
1258
+
1259
+ @property
1260
+ def text(self) -> str:
1261
+ """Get config text."""
1262
+ return self._text
1263
+
1264
+ @property
1265
+ def env_variables(self) -> dict:
1266
+ """Get used environment variables."""
1267
+ return self._env_variables
1268
+
1269
+ @property
1270
+ def pretty_text(self) -> str:
1271
+ """Get formatted python config text."""
1272
+
1273
+ indent = 4
1274
+
1275
+ def _indent(s_, num_spaces):
1276
+ s = s_.split('\n')
1277
+ if len(s) == 1:
1278
+ return s_
1279
+ first = s.pop(0)
1280
+ s = [(num_spaces * ' ') + line for line in s]
1281
+ s = '\n'.join(s)
1282
+ s = first + '\n' + s
1283
+ return s
1284
+
1285
+ def _format_basic_types(k, v, use_mapping=False):
1286
+ if isinstance(v, str):
1287
+ v_str = repr(v)
1288
+ else:
1289
+ v_str = str(v)
1290
+
1291
+ if use_mapping:
1292
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
1293
+ attr_str = f'{k_str}: {v_str}'
1294
+ else:
1295
+ attr_str = f'{str(k)}={v_str}'
1296
+ attr_str = _indent(attr_str, indent)
1297
+
1298
+ return attr_str
1299
+
1300
+ def _format_list_tuple(k, v, use_mapping=False):
1301
+ if isinstance(v, list):
1302
+ left = '['
1303
+ right = ']'
1304
+ else:
1305
+ left = '('
1306
+ right = ')'
1307
+
1308
+ v_str = f'{left}\n'
1309
+ # check if all items in the list are dict
1310
+ for item in v:
1311
+ if isinstance(item, dict):
1312
+ v_str += f'dict({_indent(_format_dict(item), indent)}),\n'
1313
+ elif isinstance(item, tuple):
1314
+ v_str += f'{_indent(_format_list_tuple(None, item), indent)},\n' # noqa: 501
1315
+ elif isinstance(item, list):
1316
+ v_str += f'{_indent(_format_list_tuple(None, item), indent)},\n' # noqa: 501
1317
+ elif isinstance(item, str):
1318
+ v_str += f'{_indent(repr(item), indent)},\n'
1319
+ else:
1320
+ v_str += str(item) + ',\n'
1321
+ if k is None:
1322
+ return _indent(v_str, indent) + right
1323
+ if use_mapping:
1324
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
1325
+ attr_str = f'{k_str}: {v_str}'
1326
+ else:
1327
+ attr_str = f'{str(k)}={v_str}'
1328
+ attr_str = _indent(attr_str, indent) + right
1329
+ return attr_str
1330
+
1331
+ def _contain_invalid_identifier(dict_str):
1332
+ contain_invalid_identifier = False
1333
+ for key_name in dict_str:
1334
+ contain_invalid_identifier |= \
1335
+ (not str(key_name).isidentifier())
1336
+ return contain_invalid_identifier
1337
+
1338
+ def _format_dict(input_dict, outest_level=False):
1339
+ r = ''
1340
+ s = []
1341
+
1342
+ use_mapping = _contain_invalid_identifier(input_dict)
1343
+ if use_mapping:
1344
+ r += '{'
1345
+ for idx, (k, v) in enumerate(
1346
+ sorted(input_dict.items(), key=lambda x: str(x[0]))):
1347
+ is_last = idx >= len(input_dict) - 1
1348
+ end = '' if outest_level or is_last else ','
1349
+ if isinstance(v, dict):
1350
+ v_str = '\n' + _format_dict(v)
1351
+ if use_mapping:
1352
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
1353
+ attr_str = f'{k_str}: dict({v_str}'
1354
+ else:
1355
+ attr_str = f'{str(k)}=dict({v_str}'
1356
+ attr_str = _indent(attr_str, indent) + ')' + end
1357
+ elif isinstance(v, (list, tuple)):
1358
+ attr_str = _format_list_tuple(k, v, use_mapping) + end
1359
+ else:
1360
+ attr_str = _format_basic_types(k, v, use_mapping) + end
1361
+
1362
+ s.append(attr_str)
1363
+ r += '\n'.join(s)
1364
+ if use_mapping:
1365
+ r += '}'
1366
+ return r
1367
+
1368
+ cfg_dict = self.to_dict()
1369
+ text = _format_dict(cfg_dict, outest_level=True)
1370
+ if self._format_python_code:
1371
+ # copied from setup.cfg
1372
+ yapf_style = dict(
1373
+ based_on_style='pep8',
1374
+ blank_line_before_nested_class_or_def=True,
1375
+ split_before_expression_after_opening_paren=True)
1376
+ try:
1377
+ from ...utils import digit_version
1378
+ if digit_version(yapf.__version__) >= digit_version('0.40.2'):
1379
+ text, _ = FormatCode(text, style_config=yapf_style)
1380
+ else:
1381
+ text, _ = FormatCode(
1382
+ text, style_config=yapf_style, verify=True)
1383
+ except: # noqa: E722
1384
+ raise SyntaxError('Failed to format the config file, please '
1385
+ f'check the syntax of: \n{text}')
1386
+ return text
1387
+
1388
+ def __repr__(self):
1389
+ return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}'
1390
+
1391
+ def __len__(self):
1392
+ return len(self._cfg_dict)
1393
+
1394
+ def __getattr__(self, name: str) -> Any:
1395
+ return getattr(self._cfg_dict, name)
1396
+
1397
+ def __getitem__(self, name):
1398
+ return self._cfg_dict.__getitem__(name)
1399
+
1400
+ def __setattr__(self, name, value):
1401
+ if isinstance(value, dict):
1402
+ value = ConfigDict(value)
1403
+ self._cfg_dict.__setattr__(name, value)
1404
+
1405
+ def __setitem__(self, name, value):
1406
+ if isinstance(value, dict):
1407
+ value = ConfigDict(value)
1408
+ self._cfg_dict.__setitem__(name, value)
1409
+
1410
+ def __iter__(self):
1411
+ return iter(self._cfg_dict)
1412
+
1413
+ def __getstate__(
1414
+ self
1415
+ ) -> Tuple[dict, Optional[str], Optional[str], dict, bool, set]:
1416
+ state = (self._cfg_dict, self._filename, self._text,
1417
+ self._env_variables, self._format_python_code,
1418
+ self._imported_names)
1419
+ return state
1420
+
1421
+ def __deepcopy__(self, memo):
1422
+ cls = self.__class__
1423
+ other = cls.__new__(cls)
1424
+ memo[id(self)] = other
1425
+
1426
+ for key, value in self.__dict__.items():
1427
+ super(Config, other).__setattr__(key, copy.deepcopy(value, memo))
1428
+
1429
+ return other
1430
+
1431
+ def __copy__(self):
1432
+ cls = self.__class__
1433
+ other = cls.__new__(cls)
1434
+ other.__dict__.update(self.__dict__)
1435
+ super(Config, other).__setattr__('_cfg_dict', self._cfg_dict.copy())
1436
+
1437
+ return other
1438
+
1439
+ copy = __copy__
1440
+
1441
+ def __setstate__(self, state: Tuple[dict, Optional[str], Optional[str],
1442
+ dict, bool, set]):
1443
+ super().__setattr__('_cfg_dict', state[0])
1444
+ super().__setattr__('_filename', state[1])
1445
+ super().__setattr__('_text', state[2])
1446
+ super().__setattr__('_env_variables', state[3])
1447
+ super().__setattr__('_format_python_code', state[4])
1448
+ super().__setattr__('_imported_names', state[5])
1449
+
1450
+ def dump(self, file: Optional[Union[str, Path]] = None):
1451
+ """Dump config to file or return config text.
1452
+
1453
+ Args:
1454
+ file (str or Path, optional): If not specified, then the object
1455
+ is dumped to a str, otherwise to a file specified by the filename.
1456
+ Defaults to None.
1457
+
1458
+ Returns:
1459
+ str or None: Config text.
1460
+ """
1461
+ file = str(file) if isinstance(file, Path) else file
1462
+ cfg_dict = self.to_dict()
1463
+ if file is None:
1464
+ if self.filename is None or self.filename.endswith('.py'):
1465
+ return self.pretty_text
1466
+ else:
1467
+ file_format = self.filename.split('.')[-1]
1468
+ return dump(cfg_dict, file_format=file_format)
1469
+ elif file.endswith('.py'):
1470
+ with open(file, 'w', encoding='utf-8') as f:
1471
+ f.write(self.pretty_text)
1472
+ else:
1473
+ file_format = file.split('.')[-1]
1474
+ return dump(cfg_dict, file=file, file_format=file_format)
1475
+
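A short sketch of ``dump`` (the output filename is arbitrary):

.. code-block:: python

    cfg = Config(dict(optimizer=dict(type='SGD', lr=0.01)))
    text = cfg.dump()          # no file given -> the formatted text is returned
    cfg.dump('dumped_cfg.py')  # .py suffix -> pretty_text is written to disk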
1476
+ @staticmethod
1477
+ def _is_lazy_import(filename: str) -> bool:
1478
+ if not filename.endswith('.py'):
1479
+ return False
1480
+ with open(filename, encoding='utf-8') as f:
1481
+ codes_str = f.read()
1482
+ parsed_codes = ast.parse(codes_str)
1483
+ for node in ast.walk(parsed_codes):
1484
+ if (isinstance(node, ast.Assign)
1485
+ and isinstance(node.targets[0], ast.Name)
1486
+ and node.targets[0].id == BASE_KEY):
1487
+ return False
1488
+
1489
+ if isinstance(node, ast.With):
1490
+ expr = node.items[0].context_expr
1491
+ if (not isinstance(expr, ast.Call)
1492
+ or not expr.func.id == 'read_base'): # type: ignore
1493
+ raise SyntaxError(
1494
+ 'Only `read_base` context manager can be used in the '
1495
+ 'config')
1496
+ return True
1497
+ if isinstance(node, ast.ImportFrom):
1498
+ # relative import -> lazy_import
1499
+ if node.level != 0:
1500
+ return True
1501
+ # Skip checking when using `mmengine.config` in cfg file
1502
+ if (node.module == 'mmengine' and len(node.names) == 1
1503
+ and node.names[0].name == 'Config'):
1504
+ continue
1505
+ if not isinstance(node.module, str):
1506
+ continue
1507
+ # non-builtin module -> lazy_import
1508
+ if not _is_builtin_module(node.module):
1509
+ return True
1510
+ if isinstance(node, ast.Import):
1511
+ for alias_node in node.names:
1512
+ if not _is_builtin_module(alias_node.name):
1513
+ return True
1514
+ return False
1515
+
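Informally, the heuristic classifies files like these two invented snippets:

.. code-block:: python

    import tempfile

    lazy_src = "with read_base():\n    from ._base_cfg import *\n"
    plain_src = "_base_ = ['./_base_cfg.py']\nlr = 0.01\n"

    for src, expected in [(lazy_src, True), (plain_src, False)]:
        with tempfile.NamedTemporaryFile('w', suffix='.py', delete=False) as f:
            f.write(src)
        assert Config._is_lazy_import(f.name) is expected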
1516
+ def _to_lazy_dict(self, keep_imported: bool = False) -> dict:
1517
+ """Convert the config object to a dictionary with lazy objects, and
1518
+ filter out the imported objects."""
1519
+ res = self._cfg_dict._to_lazy_dict()
1520
+ if hasattr(self, '_imported_names') and not keep_imported:
1521
+ res = {
1522
+ key: value
1523
+ for key, value in res.items()
1524
+ if key not in self._imported_names
1525
+ }
1526
+ return res
1527
+
1528
+ def to_dict(self, keep_imported: bool = False):
1529
+ """Convert all data in the config to a builtin ``dict``.
1530
+
1531
+ Args:
1532
+ keep_imported (bool): Whether to keep the imported field.
1533
+ Defaults to False
1534
+
1535
+ If you import third-party objects in the config file, all imported
1536
+ objects will be converted to a string like ``torch.optim.SGD``
1537
+ """
1538
+ cfg_dict = self._cfg_dict.to_dict()
1539
+ if hasattr(self, '_imported_names') and not keep_imported:
1540
+ cfg_dict = {
1541
+ key: value
1542
+ for key, value in cfg_dict.items()
1543
+ if key not in self._imported_names
1544
+ }
1545
+ return cfg_dict
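Attribute access and conversion back to builtin ``dict`` objects, sketched with invented keys:

.. code-block:: python

    cfg = Config(dict(model=dict(type='MixVisionTransformer', embed_dims=32)))
    cfg.model.embed_dims     # 32, attribute access goes through the ConfigDict
    plain = cfg.to_dict()    # nested ConfigDicts become plain dicts
    assert isinstance(plain['model'], dict)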
segformer_plusplus/configs/config/lazy.py ADDED
@@ -0,0 +1,267 @@
1
+ import importlib
2
+ from typing import Any, Optional, Union, Type
3
+ from collections import abc
4
+
5
+
6
+ class LazyObject:
7
+ """LazyObject is used to lazily initialize the imported module during
8
+ parsing the configuration file.
9
+
10
+ During the parsing process, syntax like:
11
+
12
+ Examples:
13
+ >>> import torch.nn as nn
14
+ >>> from mmdet.models import RetinaNet
15
+ >>> import mmcls.models
16
+ >>> import mmcls.datasets
17
+ >>> import mmcls
18
+
19
+ Will be parsed as:
20
+
21
+ Examples:
22
+ >>> # import torch.nn as nn
23
+ >>> nn = LazyObject('torch.nn')
24
+ >>> # from mmdet.models import RetinaNet
25
+ >>> RetinaNet = LazyObject('mmdet.models', 'RetinaNet')
26
+ >>> # import mmcls.models; import mmcls.datasets; import mmcls
27
+ >>> mmcls = LazyObject(['mmcls', 'mmcls.datasets', 'mmcls.models'])
28
+
29
+ ``LazyObject`` records all module information and will be further
30
+ referenced by the configuration file.
31
+
32
+ Args:
33
+ module (str or list or tuple): The module name to be imported.
34
+ imported (str, optional): The imported module name. Defaults to None.
35
+ location (str, optional): The filename and line number where the import
36
+ statement occurred. Defaults to None.
37
+ """
38
+
39
+ def __init__(self,
40
+ module: Union[str, list, tuple],
41
+ imported: Optional[str] = None,
42
+ location: Optional[str] = None):
43
+ if not isinstance(module, str) and not is_seq_of(module, str):
44
+ raise TypeError('module should be `str`, `list`, or `tuple` '
45
+ f'but got {type(module)}, this might be '
46
+ 'a bug, please report it')
47
+ self._module: Union[str, list, tuple] = module
48
+
49
+ if not isinstance(imported, str) and imported is not None:
50
+ raise TypeError('imported should be `str` or None, but got '
51
+ f'{type(imported)}, this might be '
52
+ 'a bug, please report it')
53
+ self._imported = imported
54
+ self.location = location
55
+
56
+ def build(self) -> Any:
57
+ """Return imported object.
58
+
59
+ Returns:
60
+ Any: Imported object
61
+ """
62
+ if isinstance(self._module, str):
63
+ try:
64
+ module = importlib.import_module(self._module)
65
+ except Exception as e:
66
+ raise type(e)(f'Failed to import {self._module} '
67
+ f'in {self.location} for {e}')
68
+
69
+ if self._imported is not None:
70
+ if hasattr(module, self._imported):
71
+ module = getattr(module, self._imported)
72
+ else:
73
+ raise ImportError(
74
+ f'Failed to import {self._imported} '
75
+ f'from {self._module} in {self.location}')
76
+
77
+ return module
78
+ else:
79
+ try:
80
+ for module in self._module:
81
+ importlib.import_module(module) # type: ignore
82
+ module_name = self._module[0].split('.')[0]
83
+ return importlib.import_module(module_name)
84
+ except Exception as e:
85
+ raise type(e)(f'Failed to import {self.module} '
86
+ f'in {self.location} for {e}')
87
+
88
+ @property
89
+ def module(self):
90
+ if isinstance(self._module, str):
91
+ return self._module
92
+ return self._module[0].split('.')[0]
93
+
94
+ def __call__(self, *args, **kwargs):
95
+ raise RuntimeError()
96
+
97
+ def __deepcopy__(self, memo):
98
+ return LazyObject(self._module, self._imported, self.location)
99
+
100
+ def __getattr__(self, name):
101
+ # Cannot locate the line number of the getting attribute.
102
+ # Therefore only record the filename.
103
+ if self.location is not None:
104
+ location = self.location.split(', line')[0]
105
+ else:
106
+ location = self.location
107
+ return LazyAttr(name, self, location)
108
+
109
+ def __str__(self) -> str:
110
+ if self._imported is not None:
111
+ return self._imported
112
+ return self.module
113
+
114
+ __repr__ = __str__
115
+
116
+ # `pickle.dump` will try to get the `__getstate__` and `__setstate__`
117
+ # methods of the dumped object. If these two methods are not defined,
118
+ # accessing `__getstate__` or `__setstate__` on a LazyObject would
119
+ # return a LazyAttr instead of the real method.
120
+ def __getstate__(self):
121
+ return self.__dict__
122
+
123
+ def __setstate__(self, state):
124
+ self.__dict__ = state
125
+
126
+
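For instance, assuming ``torch`` is installed (any importable package would do):

.. code-block:: python

    # `import torch.nn as nn` in a lazy config is rewritten to:
    nn = LazyObject('torch.nn')
    # Nothing is imported yet; attribute access only records a LazyAttr.
    linear = nn.Linear
    # The real import is deferred until build() is called.
    linear_cls = linear.build()  # torch.nn.Linear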
127
+ class LazyAttr:
128
+ """The attribute of the LazyObject.
129
+
130
+ When parsing the configuration file, the imported syntax will be
131
+ parsed as an assignment of a ``LazyObject``. During the subsequent parsing
132
+ process, users may reference the attributes of the LazyObject.
133
+ To ensure that these attributes also contain information needed to
134
+ reconstruct the attribute itself, LazyAttr was introduced.
135
+
136
+ Examples:
137
+ >>> models = LazyObject(['mmdet.models'])
138
+ >>> model = dict(type=models.RetinaNet)
139
+ >>> print(type(model['type'])) # <class 'mmengine.config.lazy.LazyAttr'>
140
+ >>> print(model['type'].build()) # <class 'mmdet.models.detectors.retinanet.RetinaNet'>
141
+ """ # noqa: E501
142
+
143
+ def __init__(self,
144
+ name: str,
145
+ source: Union['LazyObject', 'LazyAttr'],
146
+ location=None):
147
+ self.name = name
148
+ self.source: Union[LazyAttr, LazyObject] = source
149
+
150
+ if isinstance(self.source, LazyObject):
151
+ if isinstance(self.source._module, str):
152
+ if self.source._imported is None:
153
+ # source code:
154
+ # from xxx.yyy import zzz
155
+ # equivalent code:
156
+ # zzz = LazyObject('xxx.yyy', 'zzz')
157
+ # The source code of get attribute:
158
+ # eee = zzz.eee
159
+ # Then, `eee._module` should be "xxx.yyy.zzz"
160
+ self._module = self.source._module
161
+ else:
162
+ # source code:
163
+ # import xxx.yyy as zzz
164
+ # equivalent code:
165
+ # zzz = LazyObject('xxx.yyy')
166
+ # The source code of get attribute:
167
+ # eee = zzz.eee
168
+ # Then, `eee._module` should be "xxx.yyy"
169
+ self._module = f'{self.source._module}.{self.source}'
170
+ else:
171
+ # The source code of LazyObject should be
172
+ # 1. import xxx.yyy
173
+ # 2. import xxx.zzz
174
+ # Equivalent to
175
+ # xxx = LazyObject(['xxx.yyy', 'xxx.zzz'])
176
+
177
+ # The source code of LazyAttr should be
178
+ # eee = xxx.eee
179
+ # Then, eee._module = xxx
180
+ self._module = str(self.source)
181
+ elif isinstance(self.source, LazyAttr):
182
+ # 1. import xxx
183
+ # 2. zzz = xxx.yyy.zzz
184
+
185
+ # Equivalent to:
186
+ # xxx = LazyObject('xxx')
187
+ # zzz = xxx.yyy.zzz
188
+ # zzz._module = xxx.yyy._module + zzz.name
189
+ self._module = f'{self.source._module}.{self.source.name}'
190
+ self.location = location
191
+
192
+ @property
193
+ def module(self):
194
+ return self._module
195
+
196
+ def __call__(self, *args, **kwargs: Any) -> Any:
197
+ raise RuntimeError()
198
+
199
+ def __getattr__(self, name: str) -> 'LazyAttr':
200
+ return LazyAttr(name, self)
201
+
202
+ def __deepcopy__(self, memo):
203
+ return LazyAttr(self.name, self.source)
204
+
205
+ def build(self) -> Any:
206
+ """Return the attribute of the imported object.
207
+
208
+ Returns:
209
+ Any: attribute of the imported object.
210
+ """
211
+ obj = self.source.build()
212
+ try:
213
+ return getattr(obj, self.name)
214
+ except AttributeError:
215
+ raise ImportError(f'Failed to import {self.module}.{self.name} in '
216
+ f'{self.location}')
217
+ except ImportError as e:
218
+ raise e
219
+
220
+ def __str__(self) -> str:
221
+ return self.name
222
+
223
+ __repr__ = __str__
224
+
225
+ # `pickle.dump` will try to get the `__getstate__` and `__setstate__`
226
+ # methods of the dumped object. If these two methods are not defined,
227
+ # accessing `__getstate__` or `__setstate__` on a LazyAttr would
228
+ # return another LazyAttr instead of the real method.
229
+ def __getstate__(self):
230
+ return self.__dict__
231
+
232
+ def __setstate__(self, state):
233
+ self.__dict__ = state
234
+
235
+
236
+ def is_seq_of(seq: Any,
237
+ expected_type: Union[Type, tuple],
238
+ seq_type: Optional[Type] = None) -> bool:
239
+ """Check whether it is a sequence of some type.
240
+
241
+ Args:
242
+ seq (Sequence): The sequence to be checked.
243
+ expected_type (type or tuple): Expected type of sequence items.
244
+ seq_type (type, optional): Expected sequence type. Defaults to None.
245
+
246
+ Returns:
247
+ bool: Return True if ``seq`` is valid else False.
248
+
249
+ Examples:
250
+ >>> from mmengine.utils import is_seq_of
251
+ >>> seq = ['a', 'b', 'c']
252
+ >>> is_seq_of(seq, str)
253
+ True
254
+ >>> is_seq_of(seq, int)
255
+ False
256
+ """
257
+ if seq_type is None:
258
+ exp_seq_type = abc.Sequence
259
+ else:
260
+ assert isinstance(seq_type, type)
261
+ exp_seq_type = seq_type
262
+ if not isinstance(seq, exp_seq_type):
263
+ return False
264
+ for item in seq:
265
+ if not isinstance(item, expected_type):
266
+ return False
267
+ return True
segformer_plusplus/configs/config/utils.py ADDED
@@ -0,0 +1,647 @@
1
+ import ast
2
+ import os.path as osp
3
+ import re
4
+ import sys
5
+ import warnings
6
+ from collections import defaultdict
7
+ from importlib.util import find_spec
8
+ from typing import List, Optional, Tuple, Union
9
+ from importlib import import_module as real_import_module
10
+ import json
11
+ import pickle
12
+ from pathlib import Path
13
+ from mim.utils import package2module
14
+
15
+ import yaml
16
+ from omegaconf import OmegaConf
17
+
18
+
19
+ PYTHON_ROOT_DIR = osp.dirname(osp.dirname(sys.executable))
20
+ SYSTEM_PYTHON_PREFIX = '/usr/lib/python'
21
+
22
+ MODULE2PACKAGE = {
23
+ 'mmcls': 'mmcls',
24
+ 'mmdet': 'mmdet',
25
+ 'mmdet3d': 'mmdet3d',
26
+ 'mmseg': 'mmsegmentation',
27
+ 'mmaction': 'mmaction2',
28
+ 'mmtrack': 'mmtrack',
29
+ 'mmpose': 'mmpose',
30
+ 'mmedit': 'mmedit',
31
+ 'mmocr': 'mmocr',
32
+ 'mmgen': 'mmgen',
33
+ 'mmfewshot': 'mmfewshot',
34
+ 'mmrazor': 'mmrazor',
35
+ 'mmflow': 'mmflow',
36
+ 'mmhuman3d': 'mmhuman3d',
37
+ 'mmrotate': 'mmrotate',
38
+ 'mmselfsup': 'mmselfsup',
39
+ 'mmyolo': 'mmyolo',
40
+ 'mmpretrain': 'mmpretrain',
41
+ 'mmagic': 'mmagic',
42
+ }
43
+
44
+ # PKG2PROJECT is not a proper name to represent the mapping between module name
45
+ # (module import from) and package name (used by pip install). Therefore,
46
+ # PKG2PROJECT will be deprecated and this alias will only be kept until
47
+ # MMEngine v1.0.0
48
+ PKG2PROJECT = MODULE2PACKAGE
49
+
50
+
51
+ class ConfigParsingError(RuntimeError):
52
+ """Raise error when failed to parse pure Python style config files."""
53
+
54
+
55
+ def _get_cfg_metainfo(package_path: str, cfg_path: str) -> dict:
56
+ """Get target meta information from all 'metafile.yml' defined in the
57
+ `model-index.yml` of an external package.
58
+
59
+ Args:
60
+ package_path (str): Path of external package.
61
+ cfg_path (str): Name of experiment config.
62
+
63
+ Returns:
64
+ dict: Meta information of target experiment.
65
+ """
66
+ meta_index_path = osp.join(package_path, '.mim', 'model-index.yml')
67
+ meta_index = OmegaConf.to_container(OmegaConf.load(meta_index_path), resolve=True)
68
+ cfg_dict = dict()
69
+ for meta_path in meta_index['Import']:
70
+ meta_path = osp.join(package_path, '.mim', meta_path)
71
+ cfg_meta = OmegaConf.to_container(OmegaConf.load(meta_path), resolve=True)
72
+ for model_cfg in cfg_meta['Models']:
73
+ if 'Config' not in model_cfg:
74
+ warnings.warn(f'There is no `Config` defined in {model_cfg}')
75
+ continue
76
+ cfg_name = model_cfg['Config'].partition('/')[-1]
77
+ # Some config could have multiple weights, we only pick the
78
+ # first one.
79
+ if cfg_name in cfg_dict:
80
+ continue
81
+ cfg_dict[cfg_name] = model_cfg
82
+ if cfg_path not in cfg_dict:
83
+ raise ValueError(f'Expected configs: {cfg_dict.keys()}, but got '
84
+ f'{cfg_path}')
85
+ return cfg_dict[cfg_path]
86
+
87
+
88
+ def _get_external_cfg_path(package_path: str, cfg_file: str) -> str:
89
+ """Get config path of external package.
90
+
91
+ Args:
92
+ package_path (str): Path of external package.
93
+ cfg_file (str): Name of experiment config.
94
+
95
+ Returns:
96
+ str: Absolute config path from external package.
97
+ """
98
+ cfg_file = cfg_file.split('.')[0]
99
+ model_cfg = _get_cfg_metainfo(package_path, cfg_file)
100
+ cfg_path = osp.join(package_path, model_cfg['Config'])
101
+ check_file_exist(cfg_path)
102
+ return cfg_path
103
+
104
+
105
+ def _get_external_cfg_base_path(package_path: str, cfg_name: str) -> str:
106
+ """Get base config path of external package.
107
+
108
+ Args:
109
+ package_path (str): Path of external package.
110
+ cfg_name (str): Relative config path under the package's `.mim/configs`.
111
+
112
+ Returns:
113
+ str: Absolute config path from external package.
114
+ """
115
+ cfg_path = osp.join(package_path, '.mim', 'configs', cfg_name)
116
+ check_file_exist(cfg_path)
117
+ return cfg_path
118
+
119
+
120
+ def _get_package_and_cfg_path(cfg_path: str) -> Tuple[str, str]:
121
+ """Get package name and relative config path.
122
+
123
+ Args:
124
+ cfg_path (str): External relative config path with 'package::'.
125
+
126
+ Returns:
127
+ Tuple[str, str]: Package name and config path.
128
+ """
129
+ if re.match(r'\w*::\w*/\w*', cfg_path) is None:
130
+ raise ValueError(
131
+ '`_get_package_and_cfg_path` is used to get an external package; '
132
+ 'please specify the package name and relative config path, just '
133
+ 'like `mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py`')
134
+ package_cfg = cfg_path.split('::')
135
+ if len(package_cfg) > 2:
136
+ raise ValueError('`::` should only be used to separate package and '
137
+ 'config name, but found multiple `::` in '
138
+ f'{cfg_path}')
139
+ package, cfg_path = package_cfg
140
+ assert package in MODULE2PACKAGE, (
141
+ f'mmengine does not support to load {package} config.')
142
+ package = MODULE2PACKAGE[package]
143
+ return package, cfg_path
144
+
145
+
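For example (the mmdet config name is only illustrative):

.. code-block:: python

    package, rel_path = _get_package_and_cfg_path(
        'mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py')
    # package == 'mmdet' (looked up in MODULE2PACKAGE)
    # rel_path == 'faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py'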
146
+ class RemoveAssignFromAST(ast.NodeTransformer):
147
+ """Remove Assign node if the target's name match the key.
148
+
149
+ Args:
150
+ key (str): The target name of the Assign node.
151
+ """
152
+
153
+ def __init__(self, key):
154
+ self.key = key
155
+
156
+ def visit_Assign(self, node):
157
+ if (isinstance(node.targets[0], ast.Name)
158
+ and node.targets[0].id == self.key):
159
+ return None
160
+ else:
161
+ return node
162
+
163
+
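A small sketch of the transformer in action, mirroring how ``_file2dict`` strips the ``_base_`` assignment before compiling a config:

.. code-block:: python

    import ast

    tree = ast.parse("_base_ = ['./base.py']\nlr = 0.01\n")
    tree = RemoveAssignFromAST('_base_').visit(tree)
    # Only the `lr = 0.01` assignment is left in tree.body.
    print(ast.dump(tree))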
164
+ def _is_builtin_module(module_name: str) -> bool:
165
+ """Check if a module is a built-in module.
166
+
167
+ Args:
168
+ module_name (str): Name of the module.
169
+ """
170
+ if module_name.startswith('.'):
171
+ return False
172
+ if module_name.startswith('mmengine.config'):
173
+ return True
174
+ if module_name in sys.builtin_module_names:
175
+ return True
176
+ spec = find_spec(module_name.split('.')[0])
177
+ # Module not found
178
+ if spec is None:
179
+ return False
180
+ origin_path = getattr(spec, 'origin', None)
181
+ if origin_path is None:
182
+ return False
183
+ origin_path = osp.abspath(origin_path)
184
+ if ('site-package' in origin_path or 'dist-package' in origin_path
185
+ or not origin_path.startswith(
186
+ (PYTHON_ROOT_DIR, SYSTEM_PYTHON_PREFIX))):
187
+ return False
188
+ else:
189
+ return True
190
+
191
+
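Informally, the check behaves as follows; the exact results depend on the local Python installation:

.. code-block:: python

    _is_builtin_module('os')         # True: resolved inside the stdlib prefix
    _is_builtin_module('.relative')  # False: relative modules are never builtin
    _is_builtin_module('torch')      # False for a pip install in site-packages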
192
+ class ImportTransformer(ast.NodeTransformer):
193
+ """Convert the import syntax to the assignment of
194
+ :class:`mmengine.config.LazyObject` and preload the base variable before
195
+ parsing the configuration file.
196
+
197
+ Since you are already looking at this part of the code, I believe you must
198
+ be interested in the mechanism of the ``lazy_import`` feature of
199
+ :class:`Config`. In this docstring, we will dive deeper into its
200
+ principles.
201
+
202
+ Most OpenMMLab users may be bothered by the following:
203
+
204
+ * In most popular IDEs, users cannot navigate to the source code referenced in a
205
+ configuration file.
206
+ * In most popular IDEs, users cannot jump to the base file from the current
207
+ configuration file, which is quite painful when the inheritance
208
+ relationship is complex.
209
+
210
+ In order to solve this problem, we introduce the ``lazy_import`` mode.
211
+
212
+ A very intuitive idea for solving this problem is to import the module
213
+ corresponding to the "type" field using the ``import`` syntax. Similarly,
214
+ we can also ``import`` base file.
215
+
216
+ However, this approach has a significant drawback. It requires triggering
217
+ the import logic to parse the configuration file, which can be
218
+ time-consuming. Additionally, it implies downloading numerous dependencies
219
+ solely for the purpose of parsing the configuration file.
220
+ However, it's possible that only a portion of the config will actually be
221
+ used. For instance, the package used in the ``train_pipeline`` may not
222
+ be necessary for an evaluation task. Forcing users to download these
223
+ unused packages is not a desirable solution.
224
+
225
+ To avoid this problem, we introduce :class:`mmengine.config.LazyObject` and
226
+ :class:`mmengine.config.LazyAttr`. Before we proceed with further
227
+ explanations, you may refer to the documentation of these two modules to
228
+ gain an understanding of their functionalities.
229
+
230
+ Actually, one of the functions of ``ImportTransformer`` is to hack the
231
+ ``import`` syntax. It will replace the import syntax
232
+ (exclude import the base files) with the assignment of ``LazyObject``.
233
+
234
+ As for the import syntax of the base file, we cannot lazy import it since
235
+ we're eager to merge the fields of current file and base files. Therefore,
236
+ another function of the ``ImportTransformer`` is to collaborate with
237
+ ``Config._parse_lazy_import`` to parse the base files.
238
+
239
+ Args:
240
+ global_dict (dict): The global dict of the current configuration file.
241
+ If we divide ordinary Python syntax into two parts, namely the
242
+ import section and the non-import section (assuming a simple case
243
+ with imports at the beginning and the rest of the code following),
244
+ the variables generated by the import statements are stored in
245
+ global variables for subsequent code use. In this context,
246
+ the ``global_dict`` represents the global variables required when
247
+ executing the non-import code. ``global_dict`` will be filled
248
+ during visiting the parsed code.
249
+ base_dict (dict): All variables defined in base files.
250
+
251
+ Examples:
252
+ >>> from mmengine.config import read_base
253
+ >>>
254
+ >>>
255
+ >>> with read_base():
256
+ >>> from .._base_.default_runtime import *
257
+ >>> from .._base_.datasets.coco_detection import dataset
258
+
259
+ In this case, the base_dict will be:
260
+
261
+ Examples:
262
+ >>> base_dict = {
263
+ >>> '.._base_.default_runtime': ...
264
+ >>> '.._base_.datasets.coco_detection': dataset}
265
+
266
+ and `global_dict` will be updated like this:
267
+
268
+ Examples:
269
+ >>> global_dict.update(base_dict['.._base_.default_runtime']) # `import *` means update all data
270
+ >>> global_dict.update(dataset=base_dict['.._base_.datasets.coco_detection']['dataset']) # only update `dataset`
271
+ """ # noqa: E501
272
+
273
+ def __init__(self,
274
+ global_dict: dict,
275
+ base_dict: Optional[dict] = None,
276
+ filename: Optional[str] = None):
277
+ self.base_dict = base_dict if base_dict is not None else {}
278
+ self.global_dict = global_dict
279
+ # In Windows, the filename could be like this:
280
+ # "C:\\Users\\runneradmin\\AppData\\Local\\"
281
+ # Although it is a raw string, ast.parse will first escape
282
+ # it as the executed code:
283
+ # "C:\Users\runneradmin\AppData\Local\\\"
284
+ # As you see, the `\U` will be treated as a part of
285
+ # the escape sequence during code parsing, leading to a
286
+ # parsing error
287
+ # Here we use `encode('unicode_escape').decode()` for double escaping
288
+ if isinstance(filename, str):
289
+ filename = filename.encode('unicode_escape').decode()
290
+ self.filename = filename
291
+ self.imported_obj: set = set()
292
+ super().__init__()
293
+
294
+ def visit_ImportFrom(
295
+ self, node: ast.ImportFrom
296
+ ) -> Optional[Union[List[ast.Assign], ast.ImportFrom]]:
297
+ """Hack the ``from ... import ...`` syntax and update the global_dict.
298
+
299
+ Examples:
300
+ >>> from mmdet.models import RetinaNet
301
+
302
+ Will be parsed as:
303
+
304
+ Examples:
305
+ >>> RetinaNet = LazyObject('mmdet.models', 'RetinaNet')
306
+
307
+ ``global_dict`` will also be updated by ``base_dict`` as the
308
+ class docstring says.
309
+
310
+ Args:
311
+ node (ast.AST): The node of the current import statement.
312
+
313
+ Returns:
314
+ Optional[List[ast.Assign]]: There are three cases:
315
+
316
+ * If the node is a statement of importing base files,
317
+ None will be returned.
318
+ * If the node is a statement of importing a builtin module,
319
+ node will be directly returned
320
+ * Otherwise, it will return the assignment statements of
321
+ ``LazyObject``.
322
+ """
323
+ # Built-in modules will not be parsed as LazyObject
324
+ module = f'{node.level*"."}{node.module}'
325
+ if _is_builtin_module(module):
326
+ # Make sure builtin module will be added into `self.imported_obj`
327
+ for alias in node.names:
328
+ if alias.asname is not None:
329
+ self.imported_obj.add(alias.asname)
330
+ elif alias.name == '*':
331
+ raise ConfigParsingError(
332
+ 'Cannot import * from non-base config')
333
+ else:
334
+ self.imported_obj.add(alias.name)
335
+ return node
336
+
337
+ if module in self.base_dict:
338
+ for alias_node in node.names:
339
+ if alias_node.name == '*':
340
+ self.global_dict.update(self.base_dict[module])
341
+ return None
342
+ if alias_node.asname is not None:
343
+ base_key = alias_node.asname
344
+ else:
345
+ base_key = alias_node.name
346
+ self.global_dict[base_key] = self.base_dict[module][
347
+ alias_node.name]
348
+ return None
349
+
350
+ nodes: List[ast.Assign] = []
351
+ for alias_node in node.names:
352
+ # `ast.alias` has a lineno attribute since Python 3.10.
353
+ if hasattr(alias_node, 'lineno'):
354
+ lineno = alias_node.lineno
355
+ else:
356
+ lineno = node.lineno
357
+ if alias_node.name == '*':
358
+ # TODO: If users import * from a non-config module, it should
359
+ # fallback to import the real module and raise a warning to
360
+ # remind users the real module will be imported which will slow
361
+ # down the parsing speed.
362
+ raise ConfigParsingError(
363
+ 'Illegal syntax in config! `from xxx import *` is not '
364
+ 'allowed to appear outside the `if base:` statement')
365
+ elif alias_node.asname is not None:
366
+ # case1:
367
+ # from mmengine.dataset import BaseDataset as Dataset ->
368
+ # Dataset = LazyObject('mmengine.dataset', 'BaseDataset')
369
+ code = f'{alias_node.asname} = LazyObject("{module}", "{alias_node.name}", "{self.filename}, line {lineno}")' # noqa: E501
370
+ self.imported_obj.add(alias_node.asname)
371
+ else:
372
+ # case2:
373
+ # from mmengine.model import BaseModel
374
+ # BaseModel = LazyObject('mmengine.model', 'BaseModel')
375
+ code = f'{alias_node.name} = LazyObject("{module}", "{alias_node.name}", "{self.filename}, line {lineno}")' # noqa: E501
376
+ self.imported_obj.add(alias_node.name)
377
+ try:
378
+ nodes.append(ast.parse(code).body[0]) # type: ignore
379
+ except Exception as e:
380
+ raise ConfigParsingError(
381
+ f'Cannot import {alias_node} from {module}:\n'
382
+ '1. Cannot import * from 3rd party lib in the config '
383
+ 'file\n'
384
+ '2. Please check if the module is a base config which '
385
+ 'should be added to `_base_`\n') from e
386
+ return nodes
387
+
388
+ def visit_Import(self, node) -> Union[ast.Assign, ast.Import]:
389
+ """Work with ``_gather_abs_import_lazyobj`` to hack the ``import ...``
390
+ syntax.
391
+
392
+ Examples:
393
+ >>> import mmcls.models
394
+ >>> import mmcls.datasets
395
+ >>> import mmcls
396
+
397
+ Will be parsed as:
398
+
399
+ Examples:
400
+ >>> # import mmcls.models; import mmcls.datasets; import mmcls
401
+ >>> mmcls = LazyObject(['mmcls', 'mmcls.datasets', 'mmcls.models'])
402
+
403
+ Args:
404
+ node (ast.AST): The node of the current import statement.
405
+
406
+ Returns:
407
+ ast.Assign: If the import statement is ``import ... as ...``,
408
+ ast.Assign will be returned, otherwise node will be directly
409
+ returned.
410
+ """
411
+ # For absolute import like: `import mmdet.configs as configs`.
412
+ # It will be parsed as:
413
+ # configs = LazyObject('mmdet.configs')
414
+ # For absolute import like:
415
+ # `import mmdet.configs`
416
+ # `import mmdet.configs.default_runtime`
417
+ # This will be parsed as
418
+ # mmdet = LazyObject(['mmdet.configs.default_runtime', 'mmdet.configs'])
419
+ # However, visit_Import cannot gather other import information, so
420
+ # `_gather_abs_import_lazyobj` will gather all import information
421
+ # from the same module and construct the LazyObject.
422
+ alias_list = node.names
423
+ assert len(alias_list) == 1, (
424
+ 'Illegal syntax in config! Importing multiple modules in one line is '
425
+ 'not supported')
426
+ # TODO Support multiline import
427
+ alias = alias_list[0]
428
+ if alias.asname is not None:
429
+ self.imported_obj.add(alias.asname)
430
+ if _is_builtin_module(alias.name.split('.')[0]):
431
+ return node
432
+ return ast.parse( # type: ignore
433
+ f'{alias.asname} = LazyObject('
434
+ f'"{alias.name}",'
435
+ f'location="{self.filename}, line {node.lineno}")').body[0]
436
+ return node
437
+
438
+
439
+ def _gather_abs_import_lazyobj(tree: ast.Module,
440
+ filename: Optional[str] = None):
441
+ """Experimental implementation of gathering absolute import information."""
442
+ if isinstance(filename, str):
443
+ filename = filename.encode('unicode_escape').decode()
444
+ imported = defaultdict(list)
445
+ abs_imported = set()
446
+ new_body: List[ast.stmt] = []
447
+ # module2node is used to get lineno when Python < 3.10
448
+ module2node: dict = dict()
449
+ for node in tree.body:
450
+ if isinstance(node, ast.Import):
451
+ for alias in node.names:
452
+ # Skip converting built-in module to LazyObject
453
+ if _is_builtin_module(alias.name):
454
+ new_body.append(node)
455
+ continue
456
+ module = alias.name.split('.')[0]
457
+ module2node.setdefault(module, node)
458
+ imported[module].append(alias)
459
+ continue
460
+ new_body.append(node)
461
+
462
+ for key, value in imported.items():
463
+ names = [_value.name for _value in value]
464
+ if hasattr(value[0], 'lineno'):
465
+ lineno = value[0].lineno
466
+ else:
467
+ lineno = module2node[key].lineno
468
+ lazy_module_assign = ast.parse(
469
+ f'{key} = LazyObject({names}, location="{filename}, line {lineno}")' # noqa: E501
470
+ ) # noqa: E501
471
+ abs_imported.add(key)
472
+ new_body.insert(0, lazy_module_assign.body[0])
473
+ tree.body = new_body
474
+ return tree, abs_imported
475
+
476
+
477
+ def get_installed_path(package: str) -> str:
478
+ """Get installed path of package.
479
+
480
+ Args:
481
+ package (str): Name of package.
482
+
483
+ Example:
484
+ >>> get_installed_path('mmcls')
485
+ >>> '.../lib/python3.7/site-packages/mmcls'
486
+ """
487
+ import importlib.util
488
+
489
+ from pkg_resources import DistributionNotFound, get_distribution
490
+
491
+ # if the package name is not the same as module name, module name should be
492
+ # inferred. For example, mmcv-full is the package name, but mmcv is module
493
+ # name. If we want to get the installed path of mmcv-full, we should concat
494
+ # the pkg.location and module name
495
+ try:
496
+ pkg = get_distribution(package)
497
+ except DistributionNotFound as e:
498
+ # if the package is not installed, package path set in PYTHONPATH
499
+ # can be detected by `find_spec`
500
+ spec = importlib.util.find_spec(package)
501
+ if spec is not None:
502
+ if spec.origin is not None:
503
+ return osp.dirname(spec.origin)
504
+ else:
505
+ # `get_installed_path` cannot get the installed path of
506
+ # namespace packages
507
+ raise RuntimeError(
508
+ f'{package} is a namespace package, which is invalid '
509
+ 'for `get_installed_path`')
510
+ else:
511
+ raise e
512
+
513
+ possible_path = osp.join(pkg.location, package) # type: ignore
514
+ if osp.exists(possible_path):
515
+ return possible_path
516
+ else:
517
+ return osp.join(pkg.location, package2module(package)) # type: ignore
518
+
519
+
520
+ def import_modules_from_strings(imports, allow_failed_imports=False):
521
+ """Import modules from the given list of strings.
522
+
523
+ Args:
524
+ imports (list | str | None): The given module names to be imported.
525
+ allow_failed_imports (bool): If True, the failed imports will return
526
+ None. Otherwise, an ImportError is raise. Defaults to False.
527
+
528
+ Returns:
529
+ list[module] | module | None: The imported modules.
530
+
531
+ Examples:
532
+ >>> osp, sys = import_modules_from_strings(
533
+ ... ['os.path', 'sys'])
534
+ >>> import os.path as osp_
535
+ >>> import sys as sys_
536
+ >>> assert osp == osp_
537
+ >>> assert sys == sys_
538
+ """
539
+ if not imports:
540
+ return
541
+ single_import = False
542
+ if isinstance(imports, str):
543
+ single_import = True
544
+ imports = [imports]
545
+ if not isinstance(imports, list):
546
+ raise TypeError(
547
+ f'custom_imports must be a list but got type {type(imports)}')
548
+ imported = []
549
+ for imp in imports:
550
+ if not isinstance(imp, str):
551
+ raise TypeError(
552
+ f'{imp} is of type {type(imp)} and cannot be imported.')
553
+ try:
554
+ imported_tmp = import_module(imp)
555
+ except ImportError:
556
+ if allow_failed_imports:
557
+ warnings.warn(f'{imp} failed to import and is ignored.',
558
+ UserWarning)
559
+ imported_tmp = None
560
+ else:
561
+ raise ImportError(f'Failed to import {imp}')
562
+ imported.append(imported_tmp)
563
+ if single_import:
564
+ imported = imported[0]
565
+ return imported
566
+
567
+
568
+ def import_module(name, package=None):
569
+ """Import a module, optionally supporting relative imports."""
570
+ return real_import_module(name, package)
571
+
572
+
573
+ def is_installed(package: str) -> bool:
574
+ """Check package whether installed.
575
+
576
+ Args:
577
+ package (str): Name of package to be checked.
578
+ """
579
+ # When executing `import mmengine.runner`,
580
+ # pkg_resources will be imported and it takes too much time.
581
+ # Therefore, import it in function scope to save time.
582
+ import importlib.util
583
+ import pkg_resources
584
+ from pkg_resources import get_distribution
585
+
586
+ # refresh the pkg_resources
587
+ # more details at https://github.com/pypa/setuptools/issues/373
588
+ importlib.reload(pkg_resources)
589
+ try:
590
+ get_distribution(package)
591
+ return True
592
+ except pkg_resources.DistributionNotFound:
593
+ spec = importlib.util.find_spec(package)
594
+ if spec is None:
595
+ return False
596
+ elif spec.origin is not None:
597
+ return True
598
+ else:
599
+ return False
600
+
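A small usage sketch combining the two helpers above (purely illustrative):

    if is_installed('torch'):
        # Either the distribution is installed or the module is importable
        # from PYTHONPATH; namespace-only packages report False.
        print(get_installed_path('torch'))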
601
+
602
+ def dump(obj, file=None, file_format=None, **kwargs):
603
+ """Dump data to json/yaml/pickle strings or files (mmengine-like replacement)."""
604
+ if isinstance(file, Path):
605
+ file = str(file)
606
+
607
+ # Guess file format if not explicitly given
608
+ if file_format is None:
609
+ if isinstance(file, str):
610
+ file_format = file.split('.')[-1].lower()
611
+ elif file is None:
612
+ raise ValueError("file_format must be specified if file is None")
613
+
614
+ if file_format not in ['json', 'yaml', 'yml', 'pkl', 'pickle']:
615
+ raise TypeError(f"Unsupported file format: {file_format}")
616
+
617
+ # Convert YAML extension
618
+ if file_format == 'yml':
619
+ file_format = 'yaml'
620
+ if file_format == 'pickle':
621
+ file_format = 'pkl'
622
+
623
+ # Handle output to string
624
+ if file is None:
625
+ if file_format == 'json':
626
+ return json.dumps(obj, indent=4, **kwargs)
627
+ elif file_format == 'yaml':
628
+ return yaml.dump(obj, **kwargs)
629
+ elif file_format == 'pkl':
630
+ return pickle.dumps(obj, **kwargs)
631
+
632
+ # Handle output to file
633
+ mode = 'w' if file_format in ['json', 'yaml'] else 'wb'
634
+ with open(file, mode, encoding='utf-8' if 'b' not in mode else None) as f:
635
+ if file_format == 'json':
636
+ json.dump(obj, f, indent=4, **kwargs)
637
+ elif file_format == 'yaml':
638
+ yaml.dump(obj, f, **kwargs)
639
+ elif file_format == 'pkl':
640
+ pickle.dump(obj, f, **kwargs)
641
+
642
+ return True
643
+
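A minimal usage sketch of this mmengine-like `dump` replacement (the output path is only an example):

    cfg = dict(backbone=dict(type='MixVisionTransformer', embed_dims=32))
    yaml_str = dump(cfg, file_format='yaml')    # serialise to a YAML string
    dump(cfg, file='/tmp/segformer_cfg.json')   # or write to disk, format inferred from the suffix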
644
+
645
+ def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
646
+ if not osp.isfile(filename):
647
+ raise FileNotFoundError(msg_tmpl.format(filename))
segformer_plusplus/configs/segformer_mit_b0.py ADDED
@@ -0,0 +1,28 @@
1
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
2
+ backbone = dict(
3
+ type='MixVisionTransformer',
4
+ in_channels=3,
5
+ embed_dims=32,
6
+ num_stages=4,
7
+ num_layers=[2, 2, 2, 2],
8
+ num_heads=[1, 2, 5, 8],
9
+ patch_sizes=[7, 3, 3, 3],
10
+ sr_ratios=[8, 4, 2, 1],
11
+ out_indices=(0, 1, 2, 3),
12
+ mlp_ratio=4,
13
+ qkv_bias=True,
14
+ drop_rate=0.0,
15
+ attn_drop_rate=0.0,
16
+ drop_path_rate=0.1
17
+ )
18
+ decode_head = dict(
19
+ type='SegformerHead',
20
+ in_channels=[32, 64, 160, 256],
21
+ in_index=[0, 1, 2, 3],
22
+ channels=256,
23
+ dropout_ratio=0.1,
24
+ out_channels=19,
25
+ norm_cfg=norm_cfg,
26
+ align_corners=False,
27
+ interpolate_mode='bilinear'
28
+ )
segformer_plusplus/configs/segformer_mit_b1.py ADDED
@@ -0,0 +1,8 @@
1
+ _base_ = ['./segformer_mit_b0.py']
2
+
3
+ backbone = dict(
4
+ embed_dims=64,
5
+ )
6
+ decode_head = dict(
7
+ in_channels=[64, 128, 320, 512]
8
+ )
segformer_plusplus/configs/segformer_mit_b2.py ADDED
@@ -0,0 +1,6 @@
1
+ _base_ = ['./segformer_mit_b1.py']
2
+
3
+ backbone = dict(
4
+ embed_dims=64,
5
+ num_layers=[3, 4, 6, 3]
6
+ )
segformer_plusplus/configs/segformer_mit_b3.py ADDED
@@ -0,0 +1,6 @@
1
+ _base_ = ['./segformer_mit_b1.py']
2
+
3
+ backbone = dict(
4
+ embed_dims=64,
5
+ num_layers=[3, 4, 18, 3]
6
+ )
segformer_plusplus/configs/segformer_mit_b4.py ADDED
@@ -0,0 +1,6 @@
1
+ _base_ = ['./segformer_mit_b1.py']
2
+
3
+ backbone = dict(
4
+ embed_dims=64,
5
+ num_layers=[3, 8, 27, 3]
6
+ )
segformer_plusplus/configs/segformer_mit_b5.py ADDED
@@ -0,0 +1,6 @@
1
+ _base_ = ['./segformer_mit_b1.py']
2
+
3
+ backbone = dict(
4
+ embed_dims=64,
5
+ num_layers=[3, 6, 40, 3]
6
+ )
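The b1-b5 variants only override fields of the b0/b1 base configs; once the `_base_` chain is resolved by the packaged Config machinery, the result behaves roughly like a recursive dict update. A conceptual sketch of that merge, not the actual implementation:

    def merge_cfg(base: dict, override: dict) -> dict:
        # Recursively update nested dicts, which is effectively what
        # `_base_` inheritance does for these config files.
        merged = dict(base)
        for key, value in override.items():
            if isinstance(value, dict) and isinstance(merged.get(key), dict):
                merged[key] = merge_cfg(merged[key], value)
            else:
                merged[key] = value
        return merged

    # e.g. the resolved b5 backbone keeps the b0 defaults but ends up with
    # embed_dims=64 and num_layers=[3, 6, 40, 3].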
segformer_plusplus/model/__init__.py ADDED
@@ -0,0 +1 @@
1
+ __all__ = []
segformer_plusplus/model/backbone/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .mit import MixVisionTransformer
2
+
3
+ __all__ = ['MixVisionTransformer']
segformer_plusplus/model/backbone/mit.py ADDED
@@ -0,0 +1,477 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.utils.checkpoint as cp
6
+ from tomesd.merge import bipartite_soft_matching_random2d
7
+
8
+ from ...utils import PatchEmbed
9
+ from ...utils import nchw_to_nlc, nlc_to_nchw
10
+ from ...utils import MODELS
11
+ from ...utils import Conv2d, build_activation_layer, build_norm_layer, build_dropout
12
+ from ..base_module import BaseModule, MultiheadAttention, ModuleList, Sequential
13
+ from ..weight_init import (constant_init, normal_init,
14
+ trunc_normal_init)
15
+
16
+
17
+ class MixFFN(BaseModule):
18
+ """An implementation of MixFFN of Segformer.
19
+
20
+ The differences between MixFFN & FFN:
21
+ 1. Use 1X1 Conv to replace Linear layer.
22
+ 2. Introduce 3X3 Conv to encode positional information.
23
+ Args:
24
+ embed_dims (int): The feature dimension. Same as
25
+ `MultiheadAttention`. Defaults: 256.
26
+ feedforward_channels (int): The hidden dimension of FFNs.
27
+ Defaults: 1024.
28
+ act_cfg (dict, optional): The activation config for FFNs.
29
+ Default: dict(type='ReLU')
30
+ ffn_drop (float, optional): Probability of an element to be
31
+ zeroed in FFN. Default 0.0.
32
+ dropout_layer (obj:`ConfigDict`): The dropout_layer used
33
+ when adding the shortcut.
34
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
35
+ Default: None.
36
+ """
37
+
38
+ def __init__(self,
39
+ embed_dims,
40
+ feedforward_channels,
41
+ act_cfg=dict(type='GELU'),
42
+ ffn_drop=0.,
43
+ dropout_layer=None,
44
+ init_cfg=None):
45
+ super().__init__(init_cfg)
46
+
47
+ self.embed_dims = embed_dims
48
+ self.feedforward_channels = feedforward_channels
49
+ self.act_cfg = act_cfg
50
+ self.activate = build_activation_layer(act_cfg)
51
+
52
+ in_channels = embed_dims
53
+ fc1 = Conv2d(
54
+ in_channels=in_channels,
55
+ out_channels=feedforward_channels,
56
+ kernel_size=1,
57
+ stride=1,
58
+ bias=True)
59
+ # 3x3 depth wise conv to provide positional encode information
60
+ pe_conv = Conv2d(
61
+ in_channels=feedforward_channels,
62
+ out_channels=feedforward_channels,
63
+ kernel_size=3,
64
+ stride=1,
65
+ padding=(3 - 1) // 2,
66
+ bias=True,
67
+ groups=feedforward_channels)
68
+ fc2 = Conv2d(
69
+ in_channels=feedforward_channels,
70
+ out_channels=in_channels,
71
+ kernel_size=1,
72
+ stride=1,
73
+ bias=True)
74
+ drop = nn.Dropout(ffn_drop)
75
+ layers = [fc1, pe_conv, self.activate, drop, fc2, drop]
76
+ self.layers = Sequential(*layers)
77
+ self.dropout_layer = build_dropout(
78
+ dropout_layer) if dropout_layer else torch.nn.Identity()
79
+
80
+ def forward(self, x, hw_shape, identity=None):
81
+ out = nlc_to_nchw(x, hw_shape)
82
+ out = self.layers(out)
83
+ out = nchw_to_nlc(out)
84
+ if identity is None:
85
+ identity = x
86
+ return identity + self.dropout_layer(out)
87
+
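A shape-level sketch of MixFFN in isolation (hypothetical sizes, assuming the package utilities imported above resolve normally):

    ffn = MixFFN(embed_dims=32, feedforward_channels=128)
    x = torch.randn(2, 128 * 128, 32)   # (B, L, C) tokens of a 128x128 feature map
    out = ffn(x, hw_shape=(128, 128))   # -> (2, 16384, 32), residual already added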
88
+
89
+ class EfficientMultiheadAttention(MultiheadAttention):
90
+ """An implementation of Efficient Multi-head Attention of Segformer.
91
+
92
+ This module is modified from MultiheadAttention which is a module from
93
+ mmcv.cnn.bricks.transformer.
94
+ Args:
95
+ embed_dims (int): The embedding dimension.
96
+ num_heads (int): Parallel attention heads.
97
+ attn_drop (float): A Dropout layer on attn_output_weights.
98
+ Default: 0.0.
99
+ proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
100
+ Default: 0.0.
101
+ dropout_layer (obj:`ConfigDict`): The dropout_layer used
102
+ when adding the shortcut. Default: None.
103
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
104
+ Default: None.
105
+ batch_first (bool): Key, Query and Value are shape of
106
+ (batch, n, embed_dim)
107
+ or (n, batch, embed_dim). Default: False.
108
+ qkv_bias (bool): enable bias for qkv if True. Default True.
109
+ norm_cfg (dict): Config dict for normalization layer.
110
+ Default: dict(type='LN').
111
+ sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
112
+ Attention of Segformer. Default: 1.
113
+ """
114
+
115
+ def __init__(self,
116
+ embed_dims,
117
+ num_heads,
118
+ attn_drop=0.,
119
+ proj_drop=0.,
120
+ dropout_layer=None,
121
+ init_cfg=None,
122
+ batch_first=True,
123
+ qkv_bias=False,
124
+ tome_cfg=dict(),
125
+ norm_cfg=dict(type='LN'),
126
+ sr_ratio=1):
127
+ super().__init__(
128
+ embed_dims,
129
+ num_heads,
130
+ attn_drop,
131
+ proj_drop,
132
+ dropout_layer=dropout_layer,
133
+ init_cfg=init_cfg,
134
+ batch_first=batch_first,
135
+ bias=qkv_bias)
136
+
137
+ self.q_mode = tome_cfg.get('q_mode')
138
+ self.kv_mode = tome_cfg.get('kv_mode')
139
+ self.tome_cfg = tome_cfg
140
+
141
+ self.sr_ratio = sr_ratio
142
+ if sr_ratio > 1:
143
+ self.sr = Conv2d(
144
+ in_channels=embed_dims,
145
+ out_channels=embed_dims,
146
+ kernel_size=sr_ratio,
147
+ stride=sr_ratio)
148
+ # The ret[0] of build_norm_layer is norm name.
149
+ self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
150
+
151
+ def forward(self, x, hw_shape, identity=None):
152
+ x_q = x
153
+
154
+ if self.sr_ratio > 1:
155
+ x_kv = nlc_to_nchw(x, hw_shape)
156
+ x_kv = self.sr(x_kv)
157
+ x_kv = nchw_to_nlc(x_kv)
158
+ x_kv = self.norm(x_kv)
159
+ else:
160
+ x_kv = x
161
+
162
+ # 2D Neighbour Merging KV
163
+ if self.kv_mode == 'n2d':
164
+ kv_hw_shape = (int(hw_shape[0] / self.sr_ratio), int(hw_shape[1] / self.sr_ratio))
165
+ x_kv = nlc_to_nchw(x_kv, kv_hw_shape)
166
+ x_kv = torch.nn.functional.avg_pool2d(x_kv, kernel_size=self.tome_cfg['kv_s'],
167
+ stride=self.tome_cfg['kv_s'],
168
+ ceil_mode=True)
169
+ x_kv = nchw_to_nlc(x_kv)
170
+
171
+ # Bipartite Soft Matching (tomesd) KV
172
+ if self.kv_mode == 'bsm':
173
+ w_kv = int(hw_shape[1] / self.sr_ratio)
174
+ h_kv = int(hw_shape[0] / self.sr_ratio)
175
+ merge, unmerge = bipartite_soft_matching_random2d(metric=x_kv, w=w_kv, h=h_kv,
176
+ r=int(x_kv.size()[1] * self.tome_cfg['kv_r']),
177
+ sx=self.tome_cfg['kv_sx'], sy=self.tome_cfg['kv_sy'],
178
+ no_rand=True)
179
+ x_kv = merge(x_kv)
180
+
181
+ if identity is None:
182
+ identity = x_q
183
+
184
+ # 1D Neighbor Merging Q
185
+ if self.q_mode == 'n1d':
186
+ x_q = x_q.transpose(-2, -1)
187
+ x_q = torch.nn.functional.avg_pool1d(x_q, kernel_size=self.tome_cfg['q_s'],
188
+ stride=self.tome_cfg['q_s'],
189
+ ceil_mode=True)
190
+ x_q = x_q.transpose(-2, -1)
191
+
192
+ # 2D Neighbor Merging Q
193
+ if self.q_mode == 'n2d':
194
+ reduced_hw = (int(torch.ceil(torch.tensor(hw_shape[0] / self.tome_cfg['q_s'][0]))),
195
+ int(torch.ceil(torch.tensor(hw_shape[1] / self.tome_cfg['q_s'][1]))))
196
+ x_q = nlc_to_nchw(x_q, hw_shape)
197
+ x_q = torch.nn.functional.avg_pool2d(x_q, kernel_size=self.tome_cfg['q_s'],
198
+ stride=self.tome_cfg['q_s'],
199
+ ceil_mode=True)
200
+ x_q = nchw_to_nlc(x_q)
201
+
202
+ # Bipartite Soft Matching (tomesd) Q
203
+ if self.q_mode == 'bsm':
204
+ merge, unmerge = bipartite_soft_matching_random2d(metric=x_q, w=hw_shape[1], h=hw_shape[0],
205
+ r=int(x_q.size()[1] * self.tome_cfg['q_r']),
206
+ sx=self.tome_cfg['q_sx'], sy=self.tome_cfg['q_sy'],
207
+ no_rand=True)
208
+ x_q = merge(x_q)
209
+
210
+ # Because the dataflow('key', 'query', 'value') of
211
+ # ``torch.nn.MultiheadAttention`` is (num_query, batch,
212
+ # embed_dims), We should adjust the shape of dataflow from
213
+ # batch_first (batch, num_query, embed_dims) to num_query_first
214
+ # (num_query ,batch, embed_dims), and recover ``attn_output``
215
+ # from num_query_first to batch_first.
216
+
217
+ if self.batch_first:
218
+ x_q = x_q.transpose(0, 1)
219
+ x_kv = x_kv.transpose(0, 1)
220
+ out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]
221
+ if self.batch_first:
222
+ out = out.transpose(0, 1)
223
+
224
+ # Unmerging BSM (tome+tomesd)
225
+ if self.q_mode == 'bsm':
226
+ out = unmerge(out)
227
+
228
+ # Unmerging 1D Neighbour Merging
229
+ if self.q_mode == 'n1d':
230
+ out = out.transpose(-2, -1)
231
+ out = torch.nn.functional.interpolate(out, size=identity.size()[-2])
232
+ out = out.transpose(-2, -1)
233
+
234
+ # Unmerging 2D Neighbor Merging
235
+ if self.q_mode == 'n2d':
236
+ out = nlc_to_nchw(out, reduced_hw)
237
+ out = torch.nn.functional.interpolate(out, size=hw_shape)
238
+ out = nchw_to_nlc(out)
239
+
240
+ return identity + self.dropout_layer(self.proj_drop(out))
241
+
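To make the token bookkeeping above concrete, a hypothetical run with 2D neighbour merging on the queries (numbers are for orientation only; the tome_cfg keys mirror those read in forward):

    attn = EfficientMultiheadAttention(
        embed_dims=32, num_heads=1, sr_ratio=8,
        tome_cfg=dict(q_mode='n2d', q_s=(2, 2)))   # merge 2x2 query neighbourhoods
    x = torch.randn(2, 128 * 128, 32)              # stage-1 tokens of a 128x128 map
    out = attn(x, hw_shape=(128, 128))             # -> (2, 16384, 32)
    # Internally: 16384 queries shrink to 4096 after n2d merging, the 16384
    # keys/values shrink to 256 via the sr_ratio=8 reduction, and the attention
    # output is interpolated back to the full 16384 tokens before the residual.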
242
+
243
+ class TransformerEncoderLayer(BaseModule):
244
+ """Implements one encoder layer in Segformer.
245
+
246
+ Args:
247
+ embed_dims (int): The feature dimension.
248
+ num_heads (int): Parallel attention heads.
249
+ feedforward_channels (int): The hidden dimension for FFNs.
250
+ drop_rate (float): Probability of an element to be zeroed
251
+ after the feed forward layer. Default 0.0.
252
+ attn_drop_rate (float): The drop out rate for attention layer.
253
+ Default 0.0.
254
+ drop_path_rate (float): stochastic depth rate. Default 0.0.
255
+ qkv_bias (bool): enable bias for qkv if True.
256
+ Default: True.
257
+ act_cfg (dict): The activation config for FFNs.
258
+ Default: dict(type='GELU').
259
+ norm_cfg (dict): Config dict for normalization layer.
260
+ Default: dict(type='LN').
261
+ batch_first (bool): Key, Query and Value are shape of
262
+ (batch, n, embed_dim)
263
+ or (n, batch, embed_dim). Default: False.
264
+ init_cfg (dict, optional): Initialization config dict.
265
+ Default:None.
266
+ sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
267
+ Attention of Segformer. Default: 1.
268
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save
269
+ some memory while slowing down the training speed. Default: False.
270
+ """
271
+
272
+ def __init__(self,
273
+ embed_dims,
274
+ num_heads,
275
+ feedforward_channels,
276
+ drop_rate=0.,
277
+ attn_drop_rate=0.,
278
+ drop_path_rate=0.,
279
+ qkv_bias=True,
280
+ tome_cfg=dict(),
281
+ act_cfg=dict(type='GELU'),
282
+ norm_cfg=dict(type='LN'),
283
+ batch_first=True,
284
+ sr_ratio=1,
285
+ with_cp=False):
286
+ super().__init__()
287
+
288
+ # The ret[0] of build_norm_layer is norm name.
289
+ self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
290
+
291
+ self.attn = EfficientMultiheadAttention(
292
+ embed_dims=embed_dims,
293
+ num_heads=num_heads,
294
+ attn_drop=attn_drop_rate,
295
+ proj_drop=drop_rate,
296
+ dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
297
+ batch_first=batch_first,
298
+ qkv_bias=qkv_bias,
299
+ tome_cfg=tome_cfg,
300
+ norm_cfg=norm_cfg,
301
+ sr_ratio=sr_ratio)
302
+
303
+ # The ret[0] of build_norm_layer is norm name.
304
+ self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
305
+
306
+ self.ffn = MixFFN(
307
+ embed_dims=embed_dims,
308
+ feedforward_channels=feedforward_channels,
309
+ ffn_drop=drop_rate,
310
+ dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
311
+ act_cfg=act_cfg)
312
+
313
+ self.with_cp = with_cp
314
+
315
+ def forward(self, x, hw_shape):
316
+
317
+ def _inner_forward(x):
318
+ x = self.attn(self.norm1(x), hw_shape, identity=x)
319
+ x = self.ffn(self.norm2(x), hw_shape, identity=x)
320
+ return x
321
+
322
+ if self.with_cp and x.requires_grad:
323
+ x = cp.checkpoint(_inner_forward, x)
324
+ else:
325
+ x = _inner_forward(x)
326
+ return x
327
+
328
+
329
+ @MODELS.register_module()
330
+ class MixVisionTransformer(BaseModule):
331
+ """The backbone of Segformer.
332
+
333
+ This backbone is the implementation of `SegFormer: Simple and
334
+ Efficient Design for Semantic Segmentation with
335
+ Transformers <https://arxiv.org/abs/2105.15203>`_.
336
+ Args:
337
+ in_channels (int): Number of input channels. Default: 3.
338
+ embed_dims (int): Embedding dimension. Default: 768.
339
+ num_stages (int): The number of stages. Default: 4.
340
+ num_layers (Sequence[int]): The layer number of each transformer encode
341
+ layer. Default: [3, 4, 6, 3].
342
+ num_heads (Sequence[int]): The attention heads of each transformer
343
+ encode layer. Default: [1, 2, 4, 8].
344
+ patch_sizes (Sequence[int]): The patch_size of each overlapped patch
345
+ embedding. Default: [7, 3, 3, 3].
346
+ strides (Sequence[int]): The stride of each overlapped patch embedding.
347
+ Default: [4, 2, 2, 2].
348
+ sr_ratios (Sequence[int]): The spatial reduction rate of each
349
+ transformer encode layer. Default: [8, 4, 2, 1].
350
+ out_indices (Sequence[int] | int): Output from which stages.
351
+ Default: (0, 1, 2, 3).
352
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
353
+ Default: 4.
354
+ qkv_bias (bool): Enable bias for qkv if True. Default: True.
355
+ drop_rate (float): Probability of an element to be zeroed.
356
+ Default 0.0
357
+ attn_drop_rate (float): The drop out rate for attention layer.
358
+ Default 0.0
359
+ drop_path_rate (float): stochastic depth rate. Default 0.0
360
+ norm_cfg (dict): Config dict for normalization layer.
361
+ Default: dict(type='LN')
362
+ act_cfg (dict): The activation config for FFNs.
363
+ Default: dict(type='GELU').
364
+ pretrained (str, optional): model pretrained path. Default: None.
365
+ init_cfg (dict or list[dict], optional): Initialization config dict.
366
+ Default: None.
367
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save
368
+ some memory while slowing down the training speed. Default: False.
369
+ """
370
+
371
+ def __init__(self,
372
+ in_channels=3,
373
+ embed_dims=64,
374
+ num_stages=4,
375
+ num_layers=[3, 4, 6, 3],
376
+ num_heads=[1, 2, 4, 8],
377
+ patch_sizes=[7, 3, 3, 3],
378
+ strides=[4, 2, 2, 2],
379
+ sr_ratios=[8, 4, 2, 1],
380
+ out_indices=(0, 1, 2, 3),
381
+ mlp_ratio=4,
382
+ qkv_bias=True,
383
+ drop_rate=0.,
384
+ attn_drop_rate=0.,
385
+ drop_path_rate=0.,
386
+ tome_cfg=[dict(), dict(), dict(), dict()],
387
+ act_cfg=dict(type='GELU'),
388
+ norm_cfg=dict(type='LN', eps=1e-6),
389
+ init_cfg=None,
390
+ with_cp=False,
391
+ down_sample=False):
392
+ super().__init__(init_cfg=init_cfg)
393
+
394
+ self.embed_dims = embed_dims
395
+ self.num_stages = num_stages
396
+ self.num_layers = num_layers
397
+ self.num_heads = num_heads
398
+ self.patch_sizes = patch_sizes
399
+ self.strides = strides
400
+ self.sr_ratios = sr_ratios
401
+ self.with_cp = with_cp
402
+ self.down_sample = down_sample
403
+ assert num_stages == len(num_layers) == len(num_heads) \
404
+ == len(patch_sizes) == len(strides) == len(sr_ratios)
405
+
406
+ self.out_indices = out_indices
407
+ assert max(out_indices) < self.num_stages
408
+
409
+ # transformer encoder
410
+ dpr = [
411
+ x.item()
412
+ for x in torch.linspace(0, drop_path_rate, sum(num_layers))
413
+ ] # stochastic num_layer decay rule
414
+
415
+ cur = 0
416
+ self.layers = ModuleList()
417
+ for i, num_layer in enumerate(num_layers):
418
+ embed_dims_i = embed_dims * num_heads[i]
419
+ patch_embed = PatchEmbed(
420
+ in_channels=in_channels,
421
+ embed_dims=embed_dims_i,
422
+ kernel_size=patch_sizes[i],
423
+ stride=strides[i],
424
+ padding=patch_sizes[i] // 2,
425
+ norm_cfg=norm_cfg)
426
+ layer = ModuleList([
427
+ TransformerEncoderLayer(
428
+ embed_dims=embed_dims_i,
429
+ num_heads=num_heads[i],
430
+ feedforward_channels=mlp_ratio * embed_dims_i,
431
+ drop_rate=drop_rate,
432
+ attn_drop_rate=attn_drop_rate,
433
+ drop_path_rate=dpr[cur + idx],
434
+ qkv_bias=qkv_bias,
435
+ tome_cfg=tome_cfg[i],
436
+ act_cfg=act_cfg,
437
+ norm_cfg=norm_cfg,
438
+ with_cp=with_cp,
439
+ sr_ratio=sr_ratios[i]) for idx in range(num_layer)
440
+ ])
441
+ in_channels = embed_dims_i
442
+ # The ret[0] of build_norm_layer is norm name.
443
+ norm = build_norm_layer(norm_cfg, embed_dims_i)[1]
444
+ self.layers.append(ModuleList([patch_embed, layer, norm]))
445
+ cur += num_layer
446
+
447
+ def init_weights(self):
448
+ if self.init_cfg is None:
449
+ for m in self.modules():
450
+ if isinstance(m, nn.Linear):
451
+ trunc_normal_init(m, std=.02, bias=0.)
452
+ elif isinstance(m, nn.LayerNorm):
453
+ constant_init(m, val=1.0, bias=0.)
454
+ elif isinstance(m, nn.Conv2d):
455
+ fan_out = m.kernel_size[0] * m.kernel_size[
456
+ 1] * m.out_channels
457
+ fan_out //= m.groups
458
+ normal_init(
459
+ m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
460
+ else:
461
+ super().init_weights()
462
+
463
+ def forward(self, x):
464
+ if self.down_sample:
465
+ x = torch.nn.functional.interpolate(x, scale_factor=(0.5, 0.5))
466
+ outs = []
467
+
468
+ for i, layer in enumerate(self.layers):
469
+ x, hw_shape = layer[0](x)
470
+ for block in layer[1]:
471
+ x = block(x, hw_shape)
472
+ x = layer[2](x)
473
+ x = nlc_to_nchw(x, hw_shape)
474
+ if i in self.out_indices:
475
+ outs.append(x)
476
+
477
+ return outs
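A hedged smoke test of the backbone with the MiT-b0 hyper-parameters from the configs above (random weights; pretrained loading and the decode head live elsewhere in the package):

    backbone = MixVisionTransformer(
        embed_dims=32,
        num_layers=[2, 2, 2, 2],
        num_heads=[1, 2, 5, 8],
        sr_ratios=[8, 4, 2, 1])
    feats = backbone(torch.randn(1, 3, 512, 512))
    # Four feature maps at strides 4/8/16/32:
    # (1, 32, 128, 128), (1, 64, 64, 64), (1, 160, 32, 32), (1, 256, 16, 16)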
segformer_plusplus/model/base_module.py ADDED
@@ -0,0 +1,390 @@
1
+ import copy
2
+ from abc import ABCMeta
3
+ from collections import defaultdict
4
+ from typing import Iterable, List, Optional, Union, Callable
5
+ import warnings
6
+ from inspect import getfullargspec
7
+ import functools
8
+ import torch.nn as nn
9
+
10
+ from .utils import is_model_wrapper
11
+ from .weight_init import PretrainedInit, initialize, update_init_info
12
+ from ..utils.activation import build_dropout
13
+ from ..utils.registry import MODELS
14
+
15
+
16
+ class BaseModule(nn.Module, metaclass=ABCMeta):
17
+ """Base module for all modules in openmmlab. ``BaseModule`` is a wrapper of
18
+ ``torch.nn.Module`` with additional functionality of parameter
19
+ initialization. Compared with ``torch.nn.Module``, ``BaseModule`` mainly
20
+ adds three attributes.
21
+
22
+ - ``init_cfg``: the config to control the initialization.
23
+ - ``init_weights``: The function of parameter initialization and recording
24
+ initialization information.
25
+ - ``_params_init_info``: Used to track the parameter initialization
26
+ information. This attribute only exists during executing the
27
+ ``init_weights``.
28
+
29
+ Note:
30
+ :obj:`PretrainedInit` has a higher priority than any other
31
+ initializer. The loaded pretrained weights will overwrite
32
+ the previous initialized weights.
33
+
34
+ Args:
35
+ init_cfg (dict or List[dict], optional): Initialization config dict.
36
+ """
37
+
38
+ def __init__(self, init_cfg: Union[dict, List[dict], None] = None):
39
+ """Initialize BaseModule, inherited from `torch.nn.Module`"""
40
+
41
+ # NOTE init_cfg can be defined in different levels, but init_cfg
42
+ # in low levels has a higher priority.
43
+
44
+ super().__init__()
45
+ # define default value of init_cfg instead of hard code
46
+ # in init_weights() function
47
+ self._is_init = False
48
+
49
+ self.init_cfg = copy.deepcopy(init_cfg)
50
+
51
+ # Backward compatibility in derived classes
52
+ # if pretrained is not None:
53
+ # warnings.warn('DeprecationWarning: pretrained is a deprecated \
54
+ # key, please consider using init_cfg')
55
+ # self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
56
+
57
+ @property
58
+ def is_init(self):
59
+ return self._is_init
60
+
61
+ @is_init.setter
62
+ def is_init(self, value):
63
+ self._is_init = value
64
+
65
+ def init_weights(self):
66
+ """Initialize the weights."""
67
+
68
+ is_top_level_module = False
69
+ # check if it is top-level module
70
+ if not hasattr(self, '_params_init_info'):
71
+ # The `_params_init_info` is used to record the initialization
72
+ # information of the parameters
73
+ # the key should be the obj:`nn.Parameter` of model and the value
74
+ # should be a dict containing
75
+ # - init_info (str): The string that describes the initialization.
76
+ # - tmp_mean_value (FloatTensor): The mean of the parameter,
77
+ # which indicates whether the parameter has been modified.
78
+ # this attribute would be deleted after all parameters
79
+ # is initialized.
80
+ self._params_init_info = defaultdict(dict)
81
+ is_top_level_module = True
82
+
83
+ # Initialize the `_params_init_info`,
84
+ # When detecting the `tmp_mean_value` of
85
+ # the corresponding parameter is changed, update related
86
+ # initialization information
87
+ for name, param in self.named_parameters():
88
+ self._params_init_info[param][
89
+ 'init_info'] = f'The value is the same before and ' \
90
+ f'after calling `init_weights` ' \
91
+ f'of {self.__class__.__name__} '
92
+ self._params_init_info[param][
93
+ 'tmp_mean_value'] = param.data.mean().cpu()
94
+
95
+ # pass `params_init_info` to all submodules
96
+ # All submodules share the same `params_init_info`,
97
+ # so it will be updated when parameters are
98
+ # modified at any level of the model.
99
+ for sub_module in self.modules():
100
+ sub_module._params_init_info = self._params_init_info
101
+
102
+ module_name = self.__class__.__name__
103
+ if not self._is_init:
104
+ if self.init_cfg:
105
+
106
+ init_cfgs = self.init_cfg
107
+ if isinstance(self.init_cfg, dict):
108
+ init_cfgs = [self.init_cfg]
109
+
110
+ # PretrainedInit has higher priority than any other init_cfg.
111
+ # Therefore we initialize `pretrained_cfg` last to overwrite
112
+ # the previous initialized weights.
113
+ # See details in https://github.com/open-mmlab/mmengine/issues/691 # noqa E501
114
+ other_cfgs = []
115
+ pretrained_cfg = []
116
+ for init_cfg in init_cfgs:
117
+ assert isinstance(init_cfg, dict)
118
+ if (init_cfg['type'] == 'Pretrained'
119
+ or init_cfg['type'] is PretrainedInit):
120
+ pretrained_cfg.append(init_cfg)
121
+ else:
122
+ other_cfgs.append(init_cfg)
123
+
124
+ initialize(self, other_cfgs)
125
+
126
+ for m in self.children():
127
+ if is_model_wrapper(m) and not hasattr(m, 'init_weights'):
128
+ m = m.module
129
+ if hasattr(m, 'init_weights') and not getattr(
130
+ m, 'is_init', False):
131
+ m.init_weights()
132
+ # users may overload the `init_weights`
133
+ update_init_info(
134
+ m,
135
+ init_info=f'Initialized by '
136
+ f'user-defined `init_weights`'
137
+ f' in {m.__class__.__name__} ')
138
+ if self.init_cfg and pretrained_cfg:
139
+ initialize(self, pretrained_cfg)
140
+ self._is_init = True
141
+
142
+ if is_top_level_module:
143
+ # `_dump_init_info` is a logging helper in upstream mmengine that was not
+ # copied into this trimmed BaseModule, so only call it when available.
+ if hasattr(self, '_dump_init_info'):
+ self._dump_init_info()
144
+
145
+ for sub_module in self.modules():
146
+ del sub_module._params_init_info
147
+
148
+ def __repr__(self):
149
+ s = super().__repr__()
150
+ if self.init_cfg:
151
+ s += f'\ninit_cfg={self.init_cfg}'
152
+ return s
153
+
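A minimal sketch of how a subclass typically drives this machinery (assuming the usual 'Normal' initializer is available in the copied weight_init module):

    class TinyHead(BaseModule):
        def __init__(self):
            super().__init__(init_cfg=dict(type='Normal', std=0.01, layer='Conv2d'))
            self.conv = nn.Conv2d(32, 19, kernel_size=1)

    head = TinyHead()
    head.init_weights()   # applies the init_cfg and records per-parameter init info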
154
+
155
+ def deprecated_api_warning(name_dict: dict,
156
+ cls_name: Optional[str] = None) -> Callable:
157
+ """A decorator to check if some arguments are deprecate and try to replace
158
+ the deprecated src_arg_name with dst_arg_name.
159
+
160
+ Args:
161
+ name_dict(dict):
162
+ key (str): Deprecate argument names.
163
+ val (str): Expected argument names.
164
+
165
+ Returns:
166
+ func: New function.
167
+ """
168
+
169
+ def api_warning_wrapper(old_func):
170
+
171
+ @functools.wraps(old_func)
172
+ def new_func(*args, **kwargs):
173
+ # get the arg spec of the decorated method
174
+ args_info = getfullargspec(old_func)
175
+ # get name of the function
176
+ func_name = old_func.__name__
177
+ if cls_name is not None:
178
+ func_name = f'{cls_name}.{func_name}'
179
+ if args:
180
+ arg_names = args_info.args[:len(args)]
181
+ for src_arg_name, dst_arg_name in name_dict.items():
182
+ if src_arg_name in arg_names:
183
+ warnings.warn(
184
+ f'"{src_arg_name}" is deprecated in '
185
+ f'`{func_name}`, please use "{dst_arg_name}" '
186
+ 'instead', DeprecationWarning)
187
+ arg_names[arg_names.index(src_arg_name)] = dst_arg_name
188
+ if kwargs:
189
+ for src_arg_name, dst_arg_name in name_dict.items():
190
+ if src_arg_name in kwargs:
191
+ assert dst_arg_name not in kwargs, (
192
+ f'The expected behavior is to replace '
193
+ f'the deprecated key `{src_arg_name}` to '
194
+ f'new key `{dst_arg_name}`, but got them '
195
+ f'in the arguments at the same time, which '
196
+ f'is confusing. `{src_arg_name} will be '
197
+ f'deprecated in the future, please '
198
+ f'use `{dst_arg_name}` instead.')
199
+
200
+ warnings.warn(
201
+ f'"{src_arg_name}" is deprecated in '
202
+ f'`{func_name}`, please use "{dst_arg_name}" '
203
+ 'instead', DeprecationWarning)
204
+ kwargs[dst_arg_name] = kwargs.pop(src_arg_name)
205
+
206
+ # apply converted arguments to the decorated method
207
+ output = old_func(*args, **kwargs)
208
+ return output
209
+
210
+ return new_func
211
+
212
+ return api_warning_wrapper
213
+
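The decorator is used below on ``MultiheadAttention.forward`` to map the legacy ``residual`` argument to ``identity``; a toy example of the same renaming:

    @deprecated_api_warning({'residual': 'identity'}, cls_name='Toy')
    def add(x, identity=None):
        return x if identity is None else x + identity

    add(1, residual=2)   # warns about `residual`, forwards it as `identity`, returns 3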
214
+
215
+ @MODELS.register_module()
216
+ class MultiheadAttention(BaseModule):
217
+ """A wrapper for ``torch.nn.MultiheadAttention``.
218
+
219
+ This module implements MultiheadAttention with identity connection,
220
+ and positional encoding is also passed as input.
221
+
222
+ Args:
223
+ embed_dims (int): The embedding dimension.
224
+ num_heads (int): Parallel attention heads.
225
+ attn_drop (float): A Dropout layer on attn_output_weights.
226
+ Default: 0.0.
227
+ proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
228
+ Default: 0.0.
229
+ dropout_layer (obj:`ConfigDict`): The dropout_layer used
230
+ when adding the shortcut.
231
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
232
+ Default: None.
233
+ batch_first (bool): When it is True, Key, Query and Value are shape of
234
+ (batch, n, embed_dim), otherwise (n, batch, embed_dim).
235
+ Defaults to False.
236
+ """
237
+
238
+ def __init__(self,
239
+ embed_dims,
240
+ num_heads,
241
+ attn_drop=0.,
242
+ proj_drop=0.,
243
+ dropout_layer=dict(type='Dropout', drop_prob=0.),
244
+ init_cfg=None,
245
+ batch_first=False,
246
+ **kwargs):
247
+ super().__init__(init_cfg)
248
+ if 'dropout' in kwargs:
249
+ warnings.warn(
250
+ 'The arguments `dropout` in MultiheadAttention '
251
+ 'has been deprecated, now you can separately '
252
+ 'set `attn_drop`(float), proj_drop(float), '
253
+ 'and `dropout_layer`(dict) ', DeprecationWarning)
254
+ attn_drop = kwargs['dropout']
255
+ dropout_layer['drop_prob'] = kwargs.pop('dropout')
256
+
257
+ self.embed_dims = embed_dims
258
+ self.num_heads = num_heads
259
+ self.batch_first = batch_first
260
+
261
+ self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
262
+ **kwargs)
263
+
264
+ self.proj_drop = nn.Dropout(proj_drop)
265
+ self.dropout_layer = build_dropout(
266
+ dropout_layer) if dropout_layer else nn.Identity()
267
+
268
+ @deprecated_api_warning({'residual': 'identity'},
269
+ cls_name='MultiheadAttention')
270
+ def forward(self,
271
+ query,
272
+ key=None,
273
+ value=None,
274
+ identity=None,
275
+ query_pos=None,
276
+ key_pos=None,
277
+ attn_mask=None,
278
+ key_padding_mask=None,
279
+ **kwargs):
280
+ """Forward function for `MultiheadAttention`.
281
+
282
+ **kwargs allow passing a more general data flow when combining
283
+ with other operations in `transformerlayer`.
284
+
285
+ Args:
286
+ query (Tensor): The input query with shape [num_queries, bs,
287
+ embed_dims] if self.batch_first is False, else
288
+ [bs, num_queries, embed_dims].
289
+ key (Tensor): The key tensor with shape [num_keys, bs,
290
+ embed_dims] if self.batch_first is False, else
291
+ [bs, num_keys, embed_dims] .
292
+ If None, the ``query`` will be used. Defaults to None.
293
+ value (Tensor): The value tensor with same shape as `key`.
294
+ Same in `nn.MultiheadAttention.forward`. Defaults to None.
295
+ If None, the `key` will be used.
296
+ identity (Tensor): This tensor, with the same shape as x,
297
+ will be used for the identity link.
298
+ If None, `x` will be used. Defaults to None.
299
+ query_pos (Tensor): The positional encoding for query, with
300
+ the same shape as `x`. If not None, it will
301
+ be added to `x` before forward function. Defaults to None.
302
+ key_pos (Tensor): The positional encoding for `key`, with the
303
+ same shape as `key`. Defaults to None. If not None, it will
304
+ be added to `key` before forward function. If None, and
305
+ `query_pos` has the same shape as `key`, then `query_pos`
306
+ will be used for `key_pos`. Defaults to None.
307
+ attn_mask (Tensor): ByteTensor mask with shape [num_queries,
308
+ num_keys]. Same in `nn.MultiheadAttention.forward`.
309
+ Defaults to None.
310
+ key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
311
+ Defaults to None.
312
+
313
+ Returns:
314
+ Tensor: forwarded results with shape
315
+ [num_queries, bs, embed_dims]
316
+ if self.batch_first is False, else
317
+ [bs, num_queries, embed_dims].
318
+ """
319
+
320
+ if key is None:
321
+ key = query
322
+ if value is None:
323
+ value = key
324
+ if identity is None:
325
+ identity = query
326
+ if key_pos is None:
327
+ if query_pos is not None:
328
+ # use query_pos if key_pos is not available
329
+ if query_pos.shape == key.shape:
330
+ key_pos = query_pos
331
+ if query_pos is not None:
332
+ query = query + query_pos
333
+ if key_pos is not None:
334
+ key = key + key_pos
335
+
336
+ # Because the dataflow('key', 'query', 'value') of
337
+ # ``torch.nn.MultiheadAttention`` is (num_query, batch,
338
+ # embed_dims), We should adjust the shape of dataflow from
339
+ # batch_first (batch, num_query, embed_dims) to num_query_first
340
+ # (num_query ,batch, embed_dims), and recover ``attn_output``
341
+ # from num_query_first to batch_first.
342
+ if self.batch_first:
343
+ query = query.transpose(0, 1)
344
+ key = key.transpose(0, 1)
345
+ value = value.transpose(0, 1)
346
+
347
+ out = self.attn(
348
+ query=query,
349
+ key=key,
350
+ value=value,
351
+ attn_mask=attn_mask,
352
+ key_padding_mask=key_padding_mask)[0]
353
+
354
+ if self.batch_first:
355
+ out = out.transpose(0, 1)
356
+
357
+ return identity + self.dropout_layer(self.proj_drop(out))
358
+
359
+
360
+ class ModuleList(BaseModule, nn.ModuleList):
361
+ """ModuleList in openmmlab.
362
+
363
+ Ensures that all modules in ``ModuleList`` have a different initialization
364
+ strategy than the outer model
365
+
366
+ Args:
367
+ modules (iterable, optional): An iterable of modules to add.
368
+ init_cfg (dict, optional): Initialization config dict.
369
+ """
370
+
371
+ def __init__(self,
372
+ modules: Optional[Iterable] = None,
373
+ init_cfg: Optional[dict] = None):
374
+ BaseModule.__init__(self, init_cfg)
375
+ nn.ModuleList.__init__(self, modules)
376
+
377
+
378
+ class Sequential(BaseModule, nn.Sequential):
379
+ """Sequential module in openmmlab.
380
+
381
+ Ensures that all modules in ``Sequential`` have a different initialization
382
+ strategy than the outer model
383
+
384
+ Args:
385
+ init_cfg (dict, optional): Initialization config dict.
386
+ """
387
+
388
+ def __init__(self, *args, init_cfg: Optional[dict] = None):
389
+ BaseModule.__init__(self, init_cfg)
390
+ nn.Sequential.__init__(self, *args)
segformer_plusplus/model/head/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .segformer_head import SegformerHead
2
+
3
+ __all__ = ['SegformerHead']