import math

import torch
import torch.nn as nn
from einops import rearrange

from .structured_linear import StructuredLinear
from .blockdiag_multiply import blockdiag_multiply


class BlockdiagLinear(StructuredLinear):

    def __init__(self, *args, nblocks=4, shuffle=False, **kwargs):
        """shuffle: apply channel_shuffle operation before the matmul as in ShuffleNet
        """
        super().__init__(*args, **kwargs)
        # Round the block sizes up so the block-diagonal weight covers at least
        # (out_features, in_features); the extended sizes are multiples of nblocks.
        in_blksz = int(math.ceil(self.in_features / nblocks))
        out_blksz = int(math.ceil(self.out_features / nblocks))
        self.in_features_extended = in_blksz * nblocks
        self.out_features_extended = out_blksz * nblocks
        self.shuffle = shuffle
        self.weight = nn.Parameter(torch.empty(nblocks, out_blksz, in_blksz))
        self.reset_parameters()

    def set_weights_from_dense_init(self, dense_init_fn_):
        dense_weight = torch.empty(self.out_features_extended, self.in_features_extended,
                                   device=self.weight.device, dtype=self.weight.dtype)
        dense_init_fn_(dense_weight)
        # Scale up by sqrt(1 / density): only a 1 / nblocks fraction of the dense
        # entries is kept, so this keeps the overall weight norm roughly unchanged.
        scaling = math.sqrt(dense_weight.numel() / self.weight.numel())
        dense_weight *= scaling
        with torch.no_grad():
            nblocks = self.weight.shape[0]
            # Fill the block-diagonal parameter with the blocks from the first
            # block-row of the scaled dense weight.
            self.weight.copy_(rearrange(dense_weight, '(b o) (b1 i) -> b b1 o i',
                                        b=nblocks, b1=nblocks)[0])

    @property
    def saving(self):
        return self.weight.numel() / (self.in_features * self.out_features)

    def forward_matmul(self, x):
        x = self.preprocess(x)
        if self.shuffle:
            # Channel shuffle as in ShuffleNet: with group=nblocks, channels grouped
            # as (group, c_per_group) are reordered to (c_per_group, group),
            # e.g. 6 channels with 2 groups: [0, 1, 2, 3, 4, 5] -> [0, 3, 1, 4, 2, 5].
            x = rearrange(x, '... (group c_per_group) -> ... (c_per_group group)',
                          group=self.weight.shape[0])
        output = blockdiag_multiply(x, self.weight)
        return self.postprocess(output)
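

# A minimal reference sketch (not part of the original module) of the block-diagonal
# matmul that forward_matmul relies on. It assumes the imported blockdiag_multiply
# takes x of shape (..., nblocks * in_blksz) and weight of shape
# (nblocks, out_blksz, in_blksz) and returns (..., nblocks * out_blksz); the actual
# kernel may be implemented differently.
def _blockdiag_multiply_sketch(x, weight):
    nblocks, out_blksz, in_blksz = weight.shape
    # Split the last dimension of x into nblocks chunks of size in_blksz.
    x_blocks = rearrange(x, '... (b i) -> ... b i', b=nblocks)
    # Multiply each chunk by its own block, then concatenate the per-block outputs.
    out = torch.einsum('...bi,boi->...bo', x_blocks, weight)
    return rearrange(out, '... b o -> ... (b o)')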
					
						
class BlockdiagSparsityConfig:

    def __init__(self, nblocks, block=32, global_size=0):
        """nblocks: number of diagonal blocks in the sparsity pattern.
        block: block size (granularity) of the block-sparse layout.
        global_size: number of leading rows and columns that are kept fully dense.
        """
        self.nblocks = nblocks
        self.block = block
        self.global_size = global_size

    def make_layout(self, out_features, in_features):
        assert out_features % self.block == 0 and in_features % self.block == 0
        assert out_features % self.nblocks == 0 and in_features % self.nblocks == 0
        layout = torch.block_diag(*[torch.ones(out_features // self.nblocks,
                                               in_features // self.nblocks,
                                               dtype=torch.int32)] * self.nblocks)
        if self.global_size > 0:
            # Make the first global_size rows and columns fully dense.
            layout[:self.global_size] = 1
            layout[:, :self.global_size] = 1
        # Convert the element-level (out_features, in_features) mask to a block-level
        # (out_features // block, in_features // block) mask.
        layout = rearrange(layout, '(p blksz) (r blksz1) -> p r (blksz blksz1)',
                           blksz=self.block, blksz1=self.block)
        return (layout > 0).any(dim=-1).int()
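

if __name__ == '__main__':
    # Small usage sketch (not part of the original module): with 2 diagonal blocks,
    # 2x2 layout blocks, and no global rows/columns, an 8x8 weight yields a 4x4
    # block-level mask with two dense 2x2 blocks on its diagonal.
    config = BlockdiagSparsityConfig(nblocks=2, block=2, global_size=0)
    print(config.make_layout(8, 8))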