Abhinavexists committed on
Commit
5b9bb29
·
verified ·
1 Parent(s): c7cbc88

Upload folder using huggingface_hub

Browse files
ersvr/models/ersvr.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+ from einops import rearrange
4
+ from .feature_alignment import FeatureAlignmentBlock
5
+ from .sr_network import SRNetwork
6
+
7
class ERSVR(nn.Module):
    """Video super-resolution network that upscales the center frame of a clip.

    The frames of the clip are stacked along the channel axis, passed through
    a ``FeatureAlignmentBlock`` and then through an ``SRNetwork``. The SR
    network additionally receives a bicubic upsampling of the center frame.

    Args:
        scale_factor: Factor used for the bicubic upsampling of the center
            frame. NOTE(review): the SR network is constructed without this
            value, so it presumably only matches the bicubic branch for the
            default of 4 -- confirm against ``SRNetwork``.
    """

    def __init__(self, scale_factor=4):
        super(ERSVR, self).__init__()

        self.scale_factor = scale_factor

        # Feature alignment block; 9 input channels = 3 frames x 3 channels.
        self.feature_alignment = FeatureAlignmentBlock(in_channels=9, out_channels=64)

        # SR network
        self.sr_network = SRNetwork(in_channels=64, out_channels=3)

    def forward(self, x):
        """Super-resolve the center frame of ``x``.

        Args:
            x: Tensor of shape ``(B, 3, 3, H, W)`` -- a batch of 3 RGB frames.

        Returns:
            The tensor produced by the SR network; it has the same shape as
            the bicubic upsampling of the center frame.

        Raises:
            ValueError: If ``x`` is not 5-dimensional, or if the SR network
                output and the bicubic tensor differ in shape.
        """
        if x.dim() != 5:
            # Fail with an explicit message instead of an opaque unpack error.
            raise ValueError(
                f"Expected a 5-D input of shape (B, 3, 3, H, W), got {tuple(x.shape)}"
            )
        batch_size, num_frames, channels, height, width = x.shape

        # Stack the frames along the channel axis: (B, N, C, H, W) -> (B, N*C, H, W).
        x = x.reshape(batch_size, num_frames * channels, height, width)

        # Channels 3..5 are the RGB channels of the center frame.
        center_frame = x[:, 3:6, :, :]

        # Bicubic upsampling of the center frame for the residual connection.
        bicubic = F.interpolate(
            center_frame,
            scale_factor=self.scale_factor,
            mode='bicubic',
            align_corners=False
        )

        # Feature alignment
        features = self.feature_alignment(x)

        # SR network
        output = self.sr_network(features, bicubic)

        # Ensure output and bicubic have the same dimensions; the shapes are
        # reported in the exception itself rather than printed to stdout.
        if output.shape != bicubic.shape:
            raise ValueError(
                "Output and bicubic tensors must have the same dimensions "
                f"(output: {tuple(output.shape)}, bicubic: {tuple(bicubic.shape)})"
            )

        return output
ersvr/models/feature_alignment.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from .mbd import MBDModule
3
+
4
class FeatureAlignmentBlock(nn.Module):
    """Convolutional front-end applied to the channel-stacked frames.

    Three 3x3 convolutions, each followed by a ReLU, feed an ``MBDModule``
    that keeps the channel count at ``out_channels``.

    Args:
        in_channels: Number of channels of the stacked input (default 9).
        out_channels: Number of feature channels produced (default 64).
    """

    def __init__(self, in_channels=9, out_channels=64):
        super().__init__()

        # Channel widths seen by the three convolutions, in order.
        widths = [in_channels, out_channels, out_channels, out_channels]
        stages = []
        for src, dst in zip(widths[:-1], widths[1:]):
            stages.append(nn.Conv2d(src, dst, kernel_size=3, padding=1))
            stages.append(nn.ReLU(inplace=True))
        self.conv_layers = nn.Sequential(*stages)

        self.mbd = MBDModule(out_channels, out_channels)

    def forward(self, x):
        """Run the conv stack and then the MBD module on ``x`` (B, 9, H, W)."""
        return self.mbd(self.conv_layers(x))
ersvr/models/mbd.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
class MBDModule(nn.Module):
    """Multi-branch dilated convolution module.

    A 1x1 convolution projects the input to ``out_channels``. Three parallel
    3x3 convolutions with dilation 1, 2 and 4 are applied to that projection;
    their outputs are concatenated along the channel axis and reduced back to
    ``out_channels`` by a second 1x1 convolution.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of channels of every branch and of the output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)

        # padding == dilation keeps the spatial size unchanged for a 3x3 kernel.
        branches = []
        for rate in (1, 2, 4):
            branches.append(
                nn.Conv2d(out_channels, out_channels, kernel_size=3,
                          padding=rate, dilation=rate)
            )
        self.dilated_convs = nn.ModuleList(branches)

        self.fusion = nn.Conv2d(3 * out_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Project ``x``, run the dilated branches, then fuse their outputs."""
        projected = self.pointwise(x)
        stacked = torch.cat(
            [branch(projected) for branch in self.dilated_convs], dim=1
        )
        return self.fusion(stacked)
ersvr/models/sr_network.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from .upsampling import UpsamplingBlock
4
+
5
class SRNetwork(nn.Module):
    """Super-resolution network: conv stack, upsampling block, residual add.

    Nine 3x3 convolutions (each followed by a ReLU) with 64 channels are
    followed by an ``UpsamplingBlock`` and a final 3x3 convolution to
    ``out_channels``. The tensor passed as ``bicubic`` is added to the result.

    Args:
        in_channels: Number of channels of the input features (default 64).
        out_channels: Number of channels of the output image (default 3).
    """

    def __init__(self, in_channels=64, out_channels=3):
        super(SRNetwork, self).__init__()

        # Nine conv+ReLU pairs; only the first conv sees ``in_channels``.
        # Layers land at the same nn.Sequential indices (conv at even, ReLU
        # at odd positions) as when they are listed out one by one.
        layers = []
        channels = in_channels
        for _ in range(9):
            layers.append(nn.Conv2d(channels, 64, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
            channels = 64
        self.conv_layers = nn.Sequential(*layers)

        self.upsampling = UpsamplingBlock(64)
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=3, padding=1)

    def forward(self, x, bicubic):
        """Return the upsampled prediction plus the ``bicubic`` residual.

        Args:
            x: Feature tensor with ``in_channels`` channels.
            bicubic: Tensor added to the output of the final convolution; it
                must be broadcast-compatible with that output.

        Returns:
            ``final_conv(upsampling(conv_layers(x))) + bicubic``.
        """
        # No per-call shape printing here: forward runs for every batch, and
        # writing to stdout from it floods logs and slows inference.
        x = self.conv_layers(x)
        x = self.upsampling(x)
        x = self.final_conv(x)
        return x + bicubic
ersvr/models/student.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
class DepthwiseSeparableConv(nn.Module):
    """Depthwise separable convolution followed by batch norm and ReLU.

    A per-channel (``groups=in_channels``) convolution is followed by a 1x1
    convolution that maps ``in_channels`` to ``out_channels``; neither
    convolution has a bias term.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size of the depthwise convolution.
        stride: Stride of the depthwise convolution.
        padding: Padding of the depthwise convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channels,
            bias=False,
        )
        self.pointwise = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply depthwise conv, pointwise conv, batch norm and ReLU in order."""
        return self.relu(self.bn(self.pointwise(self.depthwise(x))))
21
+
22
class StudentSRNet(nn.Module):
    """
    Lightweight student model for video super-resolution.

    - Input: (B, 3, 3, H, W)   # 3 frames, 3 channels each
    - Output: (B, 3, H*4, W*4) # two PixelShuffle(2) stages give 4x upscaling

    ``scale_factor`` is stored on the module but is not used to build the
    layers; the upsampling path always consists of two 2x stages.
    """

    def __init__(self, scale_factor=4):
        super().__init__()
        self.scale_factor = scale_factor

        def shuffle_stage():
            # 16 -> 64 channels, then PixelShuffle(2) folds them back to 16
            # channels at twice the spatial resolution.
            return nn.Sequential(
                nn.Conv2d(16, 64, 3, padding=1),
                nn.PixelShuffle(2),
                nn.ReLU(inplace=True),
            )

        self.input_conv = nn.Conv2d(9, 16, 3, padding=1)
        self.block1 = DepthwiseSeparableConv(16, 32)
        self.block2 = DepthwiseSeparableConv(32, 32)
        self.block3 = DepthwiseSeparableConv(32, 16)
        self.upsample1 = shuffle_stage()
        self.upsample2 = shuffle_stage()
        self.output_conv = nn.Conv2d(16, 3, 3, padding=1)

    def forward(self, x):
        """Stack the frames on the channel axis and run every stage in order."""
        # (B, 3, 3, H, W) -> (B, 9, H, W)
        batch, frames, chans, rows, cols = x.shape
        out = x.reshape(batch, frames * chans, rows, cols)
        pipeline = (
            self.input_conv,
            self.block1,
            self.block2,
            self.block3,
            self.upsample1,
            self.upsample2,
            self.output_conv,
        )
        for stage in pipeline:
            out = stage(out)
        return out
ersvr/models/upsampling.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
class SubpixelUpsampling(nn.Module):
    """Sub-pixel upsampling: a 3x3 convolution followed by ``PixelShuffle``.

    The convolution expands the channels by ``scale_factor ** 2`` so that the
    pixel shuffle returns ``in_channels`` channels at ``scale_factor`` times
    the spatial resolution.

    Args:
        in_channels: Number of input (and output) channels.
        scale_factor: Spatial upscaling factor (default 2).
    """

    def __init__(self, in_channels, scale_factor=2):
        super().__init__()

        self.scale_factor = scale_factor
        expanded_channels = in_channels * (scale_factor ** 2)
        self.conv = nn.Conv2d(in_channels, expanded_channels, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)

    def forward(self, x):
        """Expand the channels of ``x`` and shuffle them into space."""
        return self.pixel_shuffle(self.conv(x))
21
+
22
class UpsamplingBlock(nn.Module):
    """Two chained ``SubpixelUpsampling`` modules (2x each, 4x in total).

    Args:
        in_channels: Number of channels, kept unchanged by both stages.
    """

    def __init__(self, in_channels):
        super().__init__()

        self.upsample1 = SubpixelUpsampling(in_channels)
        self.upsample2 = SubpixelUpsampling(in_channels)

    def forward(self, x):
        """Apply ``upsample1`` and then ``upsample2`` to ``x``."""
        return self.upsample2(self.upsample1(x))