Upload folder using huggingface_hub
Browse files- ersvr/models/ersvr.py +49 -0
- ersvr/models/feature_alignment.py +24 -0
- ersvr/models/mbd.py +28 -0
- ersvr/models/sr_network.py +44 -0
- ersvr/models/student.py +59 -0
- ersvr/models/upsampling.py +33 -0
ersvr/models/ersvr.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from einops import rearrange
|
4 |
+
from .feature_alignment import FeatureAlignmentBlock
|
5 |
+
from .sr_network import SRNetwork
|
6 |
+
|
7 |
+
class ERSVR(nn.Module):
|
8 |
+
"""Real-time Video Super Resolution Network using Recurrent Multi-Branch Dilated Convolutions"""
|
9 |
+
def __init__(self, scale_factor=4):
|
10 |
+
super(ERSVR, self).__init__()
|
11 |
+
|
12 |
+
self.scale_factor = scale_factor
|
13 |
+
|
14 |
+
# Feature alignment block
|
15 |
+
self.feature_alignment = FeatureAlignmentBlock(in_channels=9, out_channels=64)
|
16 |
+
|
17 |
+
# SR network
|
18 |
+
self.sr_network = SRNetwork(in_channels=64, out_channels=3)
|
19 |
+
|
20 |
+
def forward(self, x):
|
21 |
+
# Input shape: (B, 3, 3, H, W) - batch of 3 RGB frames
|
22 |
+
batch_size, num_frames, channels, height, width = x.shape
|
23 |
+
|
24 |
+
# Rearrange input to (B, 9, H, W)
|
25 |
+
x = rearrange(x, 'b n c h w -> b (n c) h w')
|
26 |
+
|
27 |
+
# Extract center frame for residual connection
|
28 |
+
center_frame = x[:, 3:6, :, :] # RGB channels of center frame
|
29 |
+
|
30 |
+
# Bicubic upsampling of center frame for residual connection
|
31 |
+
bicubic = F.interpolate(
|
32 |
+
center_frame,
|
33 |
+
scale_factor=self.scale_factor,
|
34 |
+
mode='bicubic',
|
35 |
+
align_corners=False
|
36 |
+
)
|
37 |
+
|
38 |
+
# Feature alignment
|
39 |
+
features = self.feature_alignment(x)
|
40 |
+
|
41 |
+
# SR network
|
42 |
+
output = self.sr_network(features, bicubic)
|
43 |
+
|
44 |
+
# Ensure output and bicubic have the same dimensions
|
45 |
+
if output.shape != bicubic.shape:
|
46 |
+
print(f"Output shape: {output.shape}, Bicubic shape: {bicubic.shape}")
|
47 |
+
raise ValueError("Output and bicubic tensors must have the same dimensions")
|
48 |
+
|
49 |
+
return output
|
ersvr/models/feature_alignment.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
from .mbd import MBDModule
|
3 |
+
|
4 |
+
class FeatureAlignmentBlock(nn.Module):
|
5 |
+
"""Feature Alignment Block for processing concatenated frames"""
|
6 |
+
def __init__(self, in_channels=9, out_channels=64):
|
7 |
+
super(FeatureAlignmentBlock, self).__init__()
|
8 |
+
|
9 |
+
self.conv_layers = nn.Sequential(
|
10 |
+
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
|
11 |
+
nn.ReLU(inplace=True),
|
12 |
+
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
|
13 |
+
nn.ReLU(inplace=True),
|
14 |
+
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
|
15 |
+
nn.ReLU(inplace=True)
|
16 |
+
)
|
17 |
+
|
18 |
+
self.mbd = MBDModule(out_channels, out_channels)
|
19 |
+
|
20 |
+
def forward(self, x):
|
21 |
+
# Input shape: (B, 9, H, W) - concatenated frames
|
22 |
+
x = self.conv_layers(x)
|
23 |
+
x = self.mbd(x)
|
24 |
+
return x
|
ersvr/models/mbd.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
class MBDModule(nn.Module):
|
5 |
+
"""Multi-Branch Dilated Convolution Module"""
|
6 |
+
def __init__(self, in_channels, out_channels):
|
7 |
+
super(MBDModule, self).__init__()
|
8 |
+
|
9 |
+
self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)
|
10 |
+
|
11 |
+
self.dilated_convs = nn.ModuleList([
|
12 |
+
nn.Conv2d(out_channels, out_channels, kernel_size=3,
|
13 |
+
padding=d, dilation=d) for d in [1, 2, 4]
|
14 |
+
])
|
15 |
+
|
16 |
+
self.fusion = nn.Conv2d(out_channels * 3, out_channels, kernel_size=1)
|
17 |
+
|
18 |
+
def forward(self, x):
|
19 |
+
x = self.pointwise(x)
|
20 |
+
|
21 |
+
dilated_outputs = []
|
22 |
+
for conv in self.dilated_convs:
|
23 |
+
dilated_outputs.append(conv(x))
|
24 |
+
|
25 |
+
x = torch.cat(dilated_outputs, dim=1)
|
26 |
+
x = self.fusion(x)
|
27 |
+
|
28 |
+
return x
|
ersvr/models/sr_network.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from .upsampling import UpsamplingBlock
|
4 |
+
|
5 |
+
class SRNetwork(nn.Module):
|
6 |
+
"""Super Resolution Network with ESPCN-like backbone"""
|
7 |
+
def __init__(self, in_channels=64, out_channels=3):
|
8 |
+
super(SRNetwork, self).__init__()
|
9 |
+
|
10 |
+
self.conv_layers = nn.Sequential(
|
11 |
+
nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
|
12 |
+
nn.ReLU(inplace=True),
|
13 |
+
nn.Conv2d(64, 64, kernel_size=3, padding=1),
|
14 |
+
nn.ReLU(inplace=True),
|
15 |
+
nn.Conv2d(64, 64, kernel_size=3, padding=1),
|
16 |
+
nn.ReLU(inplace=True),
|
17 |
+
nn.Conv2d(64, 64, kernel_size=3, padding=1),
|
18 |
+
nn.ReLU(inplace=True),
|
19 |
+
nn.Conv2d(64, 64, kernel_size=3, padding=1),
|
20 |
+
nn.ReLU(inplace=True),
|
21 |
+
nn.Conv2d(64, 64, kernel_size=3, padding=1),
|
22 |
+
nn.ReLU(inplace=True),
|
23 |
+
nn.Conv2d(64, 64, kernel_size=3, padding=1),
|
24 |
+
nn.ReLU(inplace=True),
|
25 |
+
nn.Conv2d(64, 64, kernel_size=3, padding=1),
|
26 |
+
nn.ReLU(inplace=True),
|
27 |
+
nn.Conv2d(64, 64, kernel_size=3, padding=1),
|
28 |
+
nn.ReLU(inplace=True)
|
29 |
+
)
|
30 |
+
|
31 |
+
self.upsampling = UpsamplingBlock(64)
|
32 |
+
self.final_conv = nn.Conv2d(64, out_channels, kernel_size=3, padding=1)
|
33 |
+
|
34 |
+
def forward(self, x, bicubic):
|
35 |
+
x = self.conv_layers(x)
|
36 |
+
|
37 |
+
print(f"Before upsampling: {x.shape}")
|
38 |
+
x = self.upsampling(x)
|
39 |
+
print(f"After upsampling: {x.shape}")
|
40 |
+
print(f"Bicubic shape: {bicubic.shape}")
|
41 |
+
|
42 |
+
x = self.final_conv(x)
|
43 |
+
x = x + bicubic
|
44 |
+
return x
|
ersvr/models/student.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
class DepthwiseSeparableConv(nn.Module):
|
4 |
+
"""
|
5 |
+
Depthwise Separable Convolution Block for efficiency.
|
6 |
+
Consists of a depthwise convolution followed by a pointwise convolution.
|
7 |
+
"""
|
8 |
+
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
|
9 |
+
super().__init__()
|
10 |
+
self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, groups=in_channels, bias=False)
|
11 |
+
self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False)
|
12 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
13 |
+
self.relu = nn.ReLU(inplace=True)
|
14 |
+
|
15 |
+
def forward(self, x):
|
16 |
+
x = self.depthwise(x)
|
17 |
+
x = self.pointwise(x)
|
18 |
+
x = self.bn(x)
|
19 |
+
x = self.relu(x)
|
20 |
+
return x
|
21 |
+
|
22 |
+
class StudentSRNet(nn.Module):
|
23 |
+
"""
|
24 |
+
Ultra-lightweight Student Model for Video Super-Resolution.
|
25 |
+
- Input: (B, 3, 3, H, W) # 3 frames, 3 channels each
|
26 |
+
- Output: (B, 3, H*4, W*4) # Super-resolved center frame
|
27 |
+
Designed for real-time, mobile/edge deployment.
|
28 |
+
"""
|
29 |
+
def __init__(self, scale_factor=4):
|
30 |
+
super().__init__()
|
31 |
+
self.scale_factor = scale_factor
|
32 |
+
self.input_conv = nn.Conv2d(9, 16, 3, padding=1)
|
33 |
+
self.block1 = DepthwiseSeparableConv(16, 32)
|
34 |
+
self.block2 = DepthwiseSeparableConv(32, 32)
|
35 |
+
self.block3 = DepthwiseSeparableConv(32, 16)
|
36 |
+
self.upsample1 = nn.Sequential(
|
37 |
+
nn.Conv2d(16, 64, 3, padding=1),
|
38 |
+
nn.PixelShuffle(2),
|
39 |
+
nn.ReLU(inplace=True)
|
40 |
+
)
|
41 |
+
self.upsample2 = nn.Sequential(
|
42 |
+
nn.Conv2d(16, 64, 3, padding=1),
|
43 |
+
nn.PixelShuffle(2),
|
44 |
+
nn.ReLU(inplace=True)
|
45 |
+
)
|
46 |
+
self.output_conv = nn.Conv2d(16, 3, 3, padding=1)
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
# x: (B, 3, 3, H, W) -> (B, 9, H, W)
|
50 |
+
b, n, c, h, w = x.shape
|
51 |
+
x = x.reshape(b, n * c, h, w)
|
52 |
+
x = self.input_conv(x)
|
53 |
+
x = self.block1(x)
|
54 |
+
x = self.block2(x)
|
55 |
+
x = self.block3(x)
|
56 |
+
x = self.upsample1(x)
|
57 |
+
x = self.upsample2(x)
|
58 |
+
x = self.output_conv(x)
|
59 |
+
return x
|
ersvr/models/upsampling.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
class SubpixelUpsampling(nn.Module):
|
4 |
+
"""Subpixel Upsampling Module using PixelShuffle"""
|
5 |
+
def __init__(self, in_channels, scale_factor=2):
|
6 |
+
super(SubpixelUpsampling, self).__init__()
|
7 |
+
|
8 |
+
self.scale_factor = scale_factor
|
9 |
+
self.conv = nn.Conv2d(
|
10 |
+
in_channels,
|
11 |
+
in_channels * (scale_factor ** 2),
|
12 |
+
kernel_size=3,
|
13 |
+
padding=1
|
14 |
+
)
|
15 |
+
self.pixel_shuffle = nn.PixelShuffle(scale_factor)
|
16 |
+
|
17 |
+
def forward(self, x):
|
18 |
+
x = self.conv(x)
|
19 |
+
x = self.pixel_shuffle(x)
|
20 |
+
return x
|
21 |
+
|
22 |
+
class UpsamplingBlock(nn.Module):
|
23 |
+
"""Block for 4x upsampling using two SubpixelUpsampling modules"""
|
24 |
+
def __init__(self, in_channels):
|
25 |
+
super(UpsamplingBlock, self).__init__()
|
26 |
+
|
27 |
+
self.upsample1 = SubpixelUpsampling(in_channels)
|
28 |
+
self.upsample2 = SubpixelUpsampling(in_channels)
|
29 |
+
|
30 |
+
def forward(self, x):
|
31 |
+
x = self.upsample1(x)
|
32 |
+
x = self.upsample2(x)
|
33 |
+
return x
|