TinySR 2 Readme!
TinySR 2 is the world's smallest audio super-resolution model — and I trained it myself!
It's only 34.3 KB in size (for the GPT-5.3 Mini / 28-channel version) and can easily run on a CPU or GPU!
There are other checkpoints as well: a version by Claude 4.6 Sonnet (larger, at 216 KB, trained with a GAN) and a 12-channel version (14.3 KB, trained for ~380 steps).
Details
- Input Sample Rate: 16 kHz
- Output Sample Rate: 48 kHz
- Size: Only 34.3 KB
The model works on mono audio (and on stereo, by upscaling each channel separately)!
Variants
| File Name | Description | Architecture | Audio Used | Steps/Epochs Trained | Size |
|---|---|---|---|---|---|
| TinySR 2 - Run 2 - 2500 Steps.pt | GPT-5.3 Mini version, 28 channels | 28-channel Conv1D, SiLU (was ReLU on TinySR 1) | 77 minutes of AI audio* | Run 2, 2500 steps | 34.3KB |
| TinySR 2 - Run 2 - 2000 Steps.pt | GPT-5.3 Mini version, 28 channels | 28-channel Conv1D, SiLU | 77 minutes of AI audio* | Run 2, 2000 steps | 32.3KB |
| TinySR 2 - Run 2 - 1000 Steps.pt | GPT-5.3 Mini version, 28 channels | 28-channel Conv1D, SiLU | 77 minutes of AI audio* | Run 2, 1000 steps | 32.3KB |
| TinySR 2 - Run 1 - 1000 Steps.pt | GPT-5.3 Mini version, 28 channels | 28-channel Conv1D, SiLU | 77 minutes of AI audio* | Run 1, 1000 steps | 32.3KB |
| TinySR 2 - 12 Channels - 380 Steps.pt | GPT-5.3 Mini version, 12 channels | 12-channel Conv1D, SiLU | 77 minutes of AI audio* | 380 steps | 14.3KB |
| TinySR 2 - Claude 4.6 Sonnet.pt | Claude 4.6 Sonnet version | 64-channel, 8-block Conv1D, SiLU, GAN, pixel-shuffle upsampling | FMA Small + LJSpeech | 12 epochs | 216KB |
*A mix of speech (ListenHub, Google's NotebookLM Audio Overviews) and music (Suno 3.5, Suno 4.5-all, Mureka 8, Mubert).
Code
For the GPT-5.3 Mini version:
# =========================================================
# TinySR 2.5
# ~6K parameter audio super-resolution model (5,965 parameters at 28 channels)
# 16 kHz -> 48 kHz
# Optimized for:
# - Tiny size
# - Better receptive field
# - Fast inference
# - Better perceptual quality
# =========================================================
import os
import random
import torch
import soundfile as sf
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
# =========================================================
# 1️⃣ Residual Dilated Block
# =========================================================
class TinyBlock(nn.Module):
    """Residual block: dilated depthwise conv -> pointwise conv -> SiLU, plus skip.

    Args:
        channels: number of channels; the block keeps it unchanged.
        dilation: dilation of the kernel-5 depthwise conv. Padding of
            ``2 * dilation`` keeps the time length unchanged, which the
            residual addition requires.
    """

    def __init__(self, channels, dilation):
        super().__init__()
        # groups=channels -> one independent temporal filter per channel.
        self.dw = nn.Conv1d(channels, channels, kernel_size=5,
                            padding=2 * dilation, dilation=dilation,
                            groups=channels)
        # 1x1 conv mixes information across channels.
        self.pw = nn.Conv1d(channels, channels, kernel_size=1)
        self.act = nn.SiLU()

    def forward(self, x):
        """Return ``x`` plus the activated depthwise/pointwise branch."""
        branch = self.act(self.pw(self.dw(x)))
        return x + branch
# =========================================================
# 2️⃣ TinySR 2.5 Model
# =========================================================
class TinySR25(nn.Module):
    """TinySR 2.5 generator: 16 kHz -> 48 kHz as interpolation + learned residual.

    The input is linearly upsampled by 3x. A stack of six dilated residual
    blocks (dilations 1, 2, 4, 8, 16, 32) then predicts a correction, which
    is added back onto the interpolated signal.

    Args:
        channels: width of the hidden feature maps (default 28).
    """

    def __init__(self, channels=28):
        super().__init__()
        self.input = nn.Conv1d(1, channels, kernel_size=1)
        # Dilation doubles per block to grow the receptive field cheaply.
        self.blocks = nn.ModuleList(
            [TinyBlock(channels, 2 ** level) for level in range(6)]
        )
        self.output = nn.Conv1d(channels, 1, kernel_size=1)

    def forward(self, x):
        """Upsample ``x`` 3x and add the predicted residual.

        ``x`` must be 3-D (batch, 1, time): linear interpolation needs a
        3-D tensor and the input conv takes a single channel.
        """
        base = F.interpolate(x, scale_factor=3, mode="linear",
                             align_corners=False)
        features = self.input(base)
        for blk in self.blocks:
            features = blk(features)
        return base + self.output(features)
# =========================================================
# 3️⃣ Dataset Loader
# =========================================================
class AudioPairDataset(Dataset):
    """Paired low-res (16 kHz) / high-res (48 kHz) training clips.

    Files are matched by name: ``lr_dir/<name>`` pairs with ``hr_dir/<name>``.
    Every item is returned as fixed-length float32 tensors of shape
    ``(1, seg_len_lr)`` and ``(1, seg_len_hr)``. Clips longer than the
    segment get a random crop. Shorter clips, or a high-res file that ends
    early, are zero-padded on the right. This lets the default DataLoader
    collate batch them, and the model output (3x the input length) always
    matches the target length.

    Args:
        lr_dir: directory with 16 kHz input files.
        hr_dir: directory with 48 kHz target files (same file names).
        segment_seconds: crop length in seconds.
    """

    def __init__(self, lr_dir, hr_dir, segment_seconds=1.0):
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.files = sorted(os.listdir(lr_dir))
        # 16 kHz input segment length, and the 3x longer 48 kHz target.
        self.seg_len_lr = int(16000 * segment_seconds)
        self.seg_len_hr = self.seg_len_lr * 3

    def __len__(self):
        return len(self.files)

    @staticmethod
    def _load_mono(path):
        """Read ``path`` as float32, averaging channels if it is multi-channel."""
        audio, _ = sf.read(path, dtype="float32")
        if audio.ndim > 1:
            audio = audio.mean(axis=1)
        return audio

    @staticmethod
    def _to_fixed(samples, length):
        """Return ``samples`` (at most ``length`` long) as a (1, length) tensor.

        Missing samples at the end are filled with zeros.
        """
        tensor = torch.from_numpy(samples).unsqueeze(0)
        return F.pad(tensor, (0, length - tensor.shape[-1]))

    def __getitem__(self, idx):
        name = self.files[idx]
        lr = self._load_mono(os.path.join(self.lr_dir, name))
        hr = self._load_mono(os.path.join(self.hr_dir, name))

        # Random crop; the high-res offset is 3x the low-res one so both
        # windows cover the same moment in time.
        start = 0
        if len(lr) > self.seg_len_lr:
            start = random.randint(0, len(lr) - self.seg_len_lr)
        lr = lr[start:start + self.seg_len_lr]
        hr = hr[start * 3:start * 3 + self.seg_len_hr]

        return (
            self._to_fixed(lr, self.seg_len_lr),
            self._to_fixed(hr, self.seg_len_hr),
        )
# =========================================================
# 4️⃣ Multi-Resolution STFT Loss
# =========================================================
def stft_loss(x, y):
    """Mean L1 distance between the STFT magnitudes of ``x`` and ``y``.

    Computed at FFT sizes 256, 512 and 1024 (hop = n_fft // 4, Hann window)
    and averaged over the three sizes. Dim 1 of each input is squeezed
    before the STFT, so (batch, 1, time) tensors are expected.
    """
    def magnitude(signal, n_fft, win):
        spec = torch.stft(signal.squeeze(1), n_fft=n_fft,
                          hop_length=n_fft // 4, window=win,
                          return_complex=True)
        return spec.abs()

    sizes = [256, 512, 1024]
    total = 0.0
    for n_fft in sizes:
        win = torch.hann_window(n_fft, device=x.device)
        gap = magnitude(x, n_fft, win) - magnitude(y, n_fft, win)
        total += gap.abs().mean()
    return total / len(sizes)
# =========================================================
# 5️⃣ Setup
# =========================================================
# Prefer the GPU whenever PyTorch can see one.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# 28-channel TinySR 2.5 generator.
model = TinySR25(channels=28).to(device)

# Total parameter count, printed as a size sanity check.
params = sum(tensor.numel() for tensor in model.parameters())
print(f"Parameters: {params}")

# AdamW optimizer, learning rate 2e-4.
opt = torch.optim.AdamW(model.parameters(), lr=2e-4)
# =========================================================
# 6️⃣ Dataset
# =========================================================
# Paired 16 kHz / 48 kHz clips from data/lr and data/hr, in 1-second segments.
dataset = AudioPairDataset(
    lr_dir="data/lr",
    hr_dir="data/hr",
    segment_seconds=1.0,
)

loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=2,
    # Pinned host memory only speeds up host->GPU copies; on a CPU-only run
    # it is useless and triggers a warning, so enable it only for CUDA.
    pin_memory=(device == "cuda"),
)
# =========================================================
# 7️⃣ Training Loop
# =========================================================
step = 0
max_steps = 100000
# Validation clip, fed straight to the model, so it should be a 16 kHz file.
val_path = "unseen.wav"
model.train()

while step < max_steps:
    for lr, hr in loader:
        lr = lr.to(device)
        hr = hr.to(device)

        # -------------------------------------------------
        # Forward
        # -------------------------------------------------
        opt.zero_grad()
        out = model(lr)

        # -------------------------------------------------
        # Losses: waveform L1 + 0.25 x multi-resolution STFT magnitude L1
        # -------------------------------------------------
        l1 = F.l1_loss(out, hr)
        spec = stft_loss(out, hr)
        loss = l1 + 0.25 * spec

        # -------------------------------------------------
        # Backprop with gradient-norm clipping at 1.0
        # -------------------------------------------------
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        step += 1

        # -------------------------------------------------
        # Logging
        # -------------------------------------------------
        if step % 10 == 0:
            print(
                f"Step {step} | "
                f"L1 {l1.item():.4f} | "
                f"STFT {spec.item():.4f} | "
                f"Total {loss.item():.4f}"
            )

        # -------------------------------------------------
        # Save checkpoint + validation
        # -------------------------------------------------
        if step % 1000 == 0:
            ckpt_name = f"tinysr25_{step}.pt"
            torch.save(model.state_dict(), ckpt_name)
            print(f"Saved {ckpt_name}")

            # A missing validation file must not abort a long training run.
            if not os.path.exists(val_path):
                print(f"Skipping validation: {val_path} not found")
            else:
                model.eval()
                unseen, unseen_sr = sf.read(val_path, dtype="float32")
                if unseen_sr != 16000:
                    print(
                        f"Warning: {val_path} is {unseen_sr} Hz, "
                        "but the model expects 16000 Hz input"
                    )
                # Stereo -> mono
                if unseen.ndim > 1:
                    unseen = unseen.mean(axis=1)
                x = torch.from_numpy(unseen).unsqueeze(0).unsqueeze(0)
                x = x.to(device)
                with torch.no_grad():
                    y = model(x)
                y = y.squeeze().cpu().numpy()
                out_name = f"val_{step}.wav"
                sf.write(out_name, y, 48000)
                print(f"Saved {out_name}")
                model.train()

        if step >= max_steps:
            break

print("Training complete!")
The code above is for the 28-channel version. The 12-channel version used mostly the same script, with only the number of channels changed.
For Claude 4.6 Sonnet's version:
"""
TinySR2 GAN Training Script
~52K parameters, optimized for Google Colab T4 GPU.
═══════════════════════════════════════════════
COLAB SETUP (run these cells first):
═══════════════════════════════════════════════
# 1. Install dependencies
!pip install torch torchaudio soundfile numpy auraloss
# 2. Download datasets
# FMA Small (~7.2GB):
!wget -c https://os.unil.cloud.switch.ch/fma/fma_small.zip
!unzip -q fma_small.zip -d ./audio
# FMA Metadata (for instrumental filter):
!wget -c https://os.unil.cloud.switch.ch/fma/fma_metadata.zip
!unzip -q fma_metadata.zip -d ./fma_metadata
# LJSpeech (~2.6GB):
!wget -c https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
!tar -xjf LJSpeech-1.1.tar.bz2 -C ./audio
# FSD50K subset — download just Eval set (~10GB → ~4GB):
# https://zenodo.org/record/4060432
# Or skip and just use FMA + LJSpeech
# 3. Run this script:
!python train_tinysr2.py
"""
import os
import glob
import random
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import soundfile as sf
from torch.utils.data import Dataset, DataLoader, ConcatDataset
# ─────────────────────────────────────────────────────────────
# CONFIG
# ─────────────────────────────────────────────────────────────
# Dataset directories — set to None to skip
# ── Data sources (a directory of None skips that source) ──
FMA_DIR = "./audio/fma_small"                # FMA Small MP3s
FMA_META = "./fma_metadata/tracks.csv"       # FMA metadata, used to split instrumental/vocal
LJSPEECH_DIR = "./audio/LJSpeech-1.1/wavs"   # LJSpeech WAVs
FSD50K_DIR = None                            # FSD50K eval set; None skips it

# ── Output locations ──
SAVE_DIR = "./checkpoints"
SAVE_PATH = "./tinysr2.pt"
LOG_PATH = "./training_log.txt"

# ── Optimisation ──
EPOCHS = 300
BATCH_SIZE = 64     # sized for a 16 GB T4; the generator itself is tiny
LR = 2e-4
NUM_WORKERS = 2
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ── Audio ──
SEGMENT_SEC = 1.0
SR_HIGH = 48000     # target sample rate
SR_LOW = 16000      # input sample rate (SR_HIGH / 3)

# ── Generator shape ──
CHANNELS = 64       # hidden width
BLOCKS = 8          # number of residual depthwise/pointwise blocks

# Hours of audio wanted per category; build_datasets consults these when
# choosing how many files to take from each source.
TARGET_HOURS = {
    "music_instrumental": 2.0,
    "music_vocals": 1.0,
    "speech": 1.0,
    "sfx": 0.5,     # ignored if FSD50K_DIR is None
}

# Set to True only once ffmpeg with libfdk_aac is installed (see simulate_heaac).
SIMULATE_HEAAC = False
# ─────────────────────────────────────────────────────────────
# GENERATOR — TinySR2
# ─────────────────────────────────────────────────────────────
class TinySR2(nn.Module):
    """TinySR2 generator (~52K parameters with the default 64 channels / 8 blocks).

    All convolutions before upsampling run at the 16 kHz input rate. A 1-D
    pixel shuffle then triples the time resolution to 48 kHz. A linearly
    interpolated copy of the input is added to the output as a global
    residual, so the network only learns the correction on top of it.
    Activations are SiLU.
    """

    def __init__(self, channels=CHANNELS, blocks=BLOCKS):
        super().__init__()
        self.input_conv = nn.Conv1d(1, channels, 9, padding=4)
        self.blocks = nn.ModuleList(
            [self._make_block(channels) for _ in range(blocks)]
        )
        # 1x1 conv emits 3 sub-sample phases per channel for the pixel shuffle.
        self.upsample_conv = nn.Conv1d(channels, channels * 3, 1)
        self.output_conv = nn.Conv1d(channels, 1, 9, padding=4)

    @staticmethod
    def _make_block(channels):
        """Depthwise kernel-9 conv -> pointwise conv -> SiLU (skip added in forward)."""
        return nn.Sequential(
            nn.Conv1d(channels, channels, 9, padding=4, groups=channels),
            nn.Conv1d(channels, channels, 1),
            nn.SiLU(),
        )

    def forward(self, x):
        # Global skip at the output rate.
        skip = F.interpolate(x, scale_factor=3, mode="linear", align_corners=False)

        h = self.input_conv(x)
        for block in self.blocks:
            h = h + block(h)

        batch, chans, frames = h.shape
        phases = self.upsample_conv(h).view(batch, chans, 3, frames)
        # (B, C, 3, T) -> (B, C, T, 3) -> (B, C, 3T): phase r of frame t lands at 3t + r.
        h = phases.transpose(2, 3).reshape(batch, chans, frames * 3)
        return self.output_conv(h) + skip
# ─────────────────────────────────────────────────────────────
# DISCRIMINATOR (training only, not deployed)
# ─────────────────────────────────────────────────────────────
class MultiResSTFTDiscriminator(nn.Module):
    """Multi-resolution STFT discriminator (training only, not deployed).

    Each FFT size gets its own small 2-D conv stack that scores the
    magnitude spectrogram. The per-resolution scores are averaged into one
    score per example, shape (batch, 1). With the default three resolutions
    the discriminator has 224,163 parameters (74,721 per stack).

    Args:
        resolutions: FFT sizes to analyse; defaults to (256, 512, 1024).
            Hop length is n_fft // 4 with a Hann window.
    """

    def __init__(self, resolutions=(256, 512, 1024)):
        super().__init__()
        self.resolutions = list(resolutions)
        self.stacks = nn.ModuleList([self._make_stack() for _ in self.resolutions])

    @staticmethod
    def _make_stack():
        """Four 3x3 convs (three with stride 2), pooled to 4x4, linear to one score."""
        return nn.Sequential(
            nn.Conv2d(1, 32, (3, 3), padding=1),
            nn.SiLU(),
            nn.Conv2d(32, 64, (3, 3), stride=(2, 2), padding=1),
            nn.SiLU(),
            nn.Conv2d(64, 64, (3, 3), stride=(2, 2), padding=1),
            nn.SiLU(),
            nn.Conv2d(64, 32, (3, 3), stride=(2, 2), padding=1),
            nn.SiLU(),
            # Fixed 4x4 output so the linear layer works for any spectrogram size.
            nn.AdaptiveAvgPool2d((4, 4)),
            nn.Flatten(),
            nn.Linear(32 * 16, 1),
        )

    def stft_mag(self, x, n_fft):
        """Return |STFT(x)| + 1e-8 as a (batch, 1, freq, frames) image.

        Dim 1 of ``x`` is squeezed first, so (batch, 1, time) input is expected.
        """
        x = x.squeeze(1)
        window = torch.hann_window(n_fft, device=x.device)
        spec = torch.stft(x, n_fft=n_fft, hop_length=n_fft // 4,
                          win_length=n_fft, window=window, return_complex=True)
        return (spec.abs() + 1e-8).unsqueeze(1)

    def forward(self, x):
        scores = [stack(self.stft_mag(x, n_fft))
                  for n_fft, stack in zip(self.resolutions, self.stacks)]
        return torch.stack(scores, dim=1).mean(dim=1)
# ─────────────────────────────────────────────────────────────
# LOSSES
# ─────────────────────────────────────────────────────────────
def multi_res_stft_loss(pred, target, resolutions=(256, 512, 1024, 2048)):
    """Multi-resolution spectral-convergence + log-magnitude loss.

    For every FFT size (hop n_fft // 4, Hann window) with magnitude
    spectrograms P = |STFT(pred)| + 1e-8 and T = |STFT(target)| + 1e-8,
    this adds:
      * spectral convergence ||T - P||_F / (||T||_F + 1e-8), and
      * mean L1 distance between log P and log T.
    The sum is averaged over resolutions.

    Args:
        pred, target: waveforms; dim 1 is squeezed before the STFT, so
            (batch, 1, time) inputs are expected.
        resolutions: iterable of FFT sizes. The default is a tuple rather
            than a list so no mutable object is shared between calls.

    Raises:
        ValueError: if ``resolutions`` is empty.
    """
    def _magnitude(x, n_fft, window):
        spec = torch.stft(x.squeeze(1), n_fft=n_fft, hop_length=n_fft // 4,
                          win_length=n_fft, window=window, return_complex=True)
        return spec.abs() + 1e-8  # epsilon keeps log() finite

    resolutions = tuple(resolutions)
    if not resolutions:
        raise ValueError("resolutions must contain at least one FFT size")

    total = 0.0
    for n_fft in resolutions:
        window = torch.hann_window(n_fft, device=pred.device)
        p_mag = _magnitude(pred, n_fft, window)
        t_mag = _magnitude(target, n_fft, window)
        sc = torch.norm(t_mag - p_mag, "fro") / (torch.norm(t_mag, "fro") + 1e-8)
        log_mag = F.l1_loss(torch.log(p_mag), torch.log(t_mag))
        total += sc + log_mag
    return total / len(resolutions)
def generator_loss(disc, fake):
    """Least-squares GAN generator loss: MSE between ``disc(fake)`` and 1.

    ``disc`` must return one score per example, shaped (batch, 1), to match
    the all-ones target.
    """
    scores = disc(fake)
    target = torch.ones(fake.size(0), 1, device=fake.device)
    return F.mse_loss(scores, target)
def discriminator_loss(disc, real, fake):
    """Least-squares GAN discriminator loss.

    Scores for ``real`` are pushed towards 1 and scores for ``fake`` towards
    0. ``fake`` is detached, so this loss sends no gradient into the
    generator. Returns the mean of the two MSE terms.
    """
    real_target = torch.ones(real.size(0), 1, device=real.device)
    real_term = F.mse_loss(disc(real), real_target)
    fake_target = torch.zeros(fake.size(0), 1, device=fake.device)
    fake_term = F.mse_loss(disc(fake.detach()), fake_target)
    return (real_term + fake_term) / 2
# ─────────────────────────────────────────────────────────────
# HE-AACv2 SIMULATION
# ─────────────────────────────────────────────────────────────
def simulate_heaac(audio_np):
    """Round-trip ``audio_np`` through an 8 kbps HE-AACv2 encode/decode.

    Steps:
      * The input is written as a WAV at SR_HIGH.
      * ffmpeg/libfdk_aac encodes it with profile aac_he_v2 at 8 kbps,
        16 kHz, mono.
      * The result is decoded back to SR_LOW.
      * The decoded signal is zero-padded or truncated to exactly one
        SEGMENT_SEC segment at SR_LOW.

    All intermediate files live in a temporary directory that is removed
    even if a step fails.

    Requires ffmpeg with libfdk_aac. Install on Colab:
        !apt-get remove -y ffmpeg
        # Download static ffmpeg build with fdk-aac:
        !wget https://johnvansickle.com/ffmpeg/releases/ffmpeg-release-amd64-static.tar.xz
        !tar -xf ffmpeg-release-amd64-static.tar.xz
        !cp ffmpeg-*-static/ffmpeg /usr/local/bin/
    NOTE(review): libfdk_aac is non-free and many prebuilt ffmpeg builds
    omit it. Verify with `ffmpeg -encoders | grep fdk` before enabling
    SIMULATE_HEAAC.
    NOTE(review): HE-AACv2 adds parametric stereo; confirm the encoder
    accepts the mono (-ac 1) input used here rather than failing.
    NOTE(review): no codec-delay compensation is done; verify the decoded
    signal stays time-aligned with the high-res target.

    Raises:
        subprocess.CalledProcessError: if either ffmpeg call fails.
        FileNotFoundError: if ffmpeg is not on PATH.
    """
    import subprocess
    import tempfile

    seg_lo = int(SEGMENT_SEC * SR_LOW)

    with tempfile.TemporaryDirectory() as tmp_dir:
        in_path = os.path.join(tmp_dir, "in.wav")
        out_aac = os.path.join(tmp_dir, "enc.aac")
        out_wav = os.path.join(tmp_dir, "dec.wav")

        sf.write(in_path, audio_np, SR_HIGH)
        # check=True surfaces encoder errors here instead of as a
        # confusing "file not found" when reading the decoded WAV.
        subprocess.run([
            "ffmpeg", "-y", "-i", in_path,
            "-c:a", "libfdk_aac", "-profile:a", "aac_he_v2",
            "-b:a", "8k", "-ar", "16000", "-ac", "1",
            out_aac,
        ], capture_output=True, check=True)
        subprocess.run([
            "ffmpeg", "-y", "-i", out_aac,
            "-ar", str(SR_LOW), out_wav,
        ], capture_output=True, check=True)
        degraded, _ = sf.read(out_wav, dtype="float32")

    # Defensive downmix in case the decoder ever returns multiple channels.
    if degraded.ndim > 1:
        degraded = degraded.mean(axis=1)
    if len(degraded) < seg_lo:
        degraded = np.pad(degraded, (0, seg_lo - len(degraded)))
    return degraded[:seg_lo]
# ─────────────────────────────────────────────────────────────
# DATASET
# ─────────────────────────────────────────────────────────────
class AudioDataset(Dataset):
    """Random SEGMENT_SEC crops from audio files, as (low, high) pairs.

    Each file is visited ``crops`` times per epoch
    (``len == len(files) * crops``), with a fresh random crop each time.
    ``high`` is the mono crop at SR_HIGH. ``low`` is one of:
      * that crop resampled to SR_LOW, or
      * with ``simulate_codec``, its HE-AACv2 round trip from simulate_heaac.
    Both are divided by the peak absolute value of ``high``.

    Files that fail to load are replaced by silent pairs, so one bad file
    cannot stop training. Failures are reported (up to MAX_WARNINGS per
    dataset object) instead of being swallowed silently, so systematic
    problems such as a missing codec are visible.
    """

    # Cap on printed load-failure warnings per dataset object. Presumably
    # each DataLoader worker holds its own copy, so the cap is per worker.
    MAX_WARNINGS = 5

    def __init__(self, files, label="", crops=30, simulate_codec=False):
        self.files = files
        self.label = label
        self.crops = crops
        self.seg_hi = int(SEGMENT_SEC * SR_HIGH)
        self.seg_lo = int(SEGMENT_SEC * SR_LOW)
        self.simulate = simulate_codec
        self._warnings_left = self.MAX_WARNINGS
        print(f"  [{label}] {len(files)} files → {len(self):,} segments/epoch")

    def __len__(self):
        return len(self.files) * self.crops

    def __getitem__(self, idx):
        path = self.files[idx % len(self.files)]
        try:
            return self._load_pair(path)
        except Exception as err:  # deliberate best effort: keep training
            if self._warnings_left > 0:
                self._warnings_left -= 1
                print(f"  [{self.label}] could not load {path} "
                      f"({type(err).__name__}: {err}); using silence instead")
            return (torch.zeros(1, self.seg_lo),
                    torch.zeros(1, self.seg_hi))

    def _load_pair(self, path):
        """Load ``path`` and return one normalized (low, high) training pair."""
        audio, sr = torchaudio.load(path)
        audio = audio.mean(0, keepdim=True)  # downmix to mono, keep a channel dim
        if sr != SR_HIGH:
            audio = torchaudio.functional.resample(audio, sr, SR_HIGH)

        # Zero-pad short files so a full segment can always be cut.
        if audio.shape[1] < self.seg_hi:
            audio = F.pad(audio, (0, self.seg_hi - audio.shape[1]))
        start = random.randint(0, max(0, audio.shape[1] - self.seg_hi))
        high = audio[:, start:start + self.seg_hi]  # (1, seg_hi)

        if self.simulate:
            low_np = simulate_heaac(high.squeeze().numpy())
            low = torch.from_numpy(low_np).unsqueeze(0)
        else:
            low = torchaudio.functional.resample(high, SR_HIGH, SR_LOW)

        # One shared scale factor keeps the low/high level relationship intact.
        peak = high.abs().max().clamp(min=1e-8)
        return (low / peak).float(), (high / peak).float()
def get_fma_files(fma_dir, meta_path):
    """Split the FMA MP3s under ``fma_dir`` into instrumental and vocal lists.

    The track id comes from the file name (e.g. ``000002.mp3`` -> 2). A track
    counts as instrumental if:
      * the metadata has a ``('track', 'instrumental')`` column and its value
        is bool/1/1.0/"true"-like, or
      * that column is absent and ``('track', 'genre_top')`` is
        ``"Instrumental"``.
    Files whose id is not numeric or not in the metadata are left out of
    both lists. If the metadata cannot be loaded, or has neither column,
    every file is returned as vocal.

    Returns:
        (inst_files, vocal_files): lists of file paths.
    """
    all_files = glob.glob(os.path.join(fma_dir, "**/*.mp3"), recursive=True)

    try:
        import pandas as pd  # optional: only needed for the instrumental split

        tracks = pd.read_csv(meta_path, index_col=0, header=[0, 1], low_memory=False)
        inst_col = ("track", "instrumental")
        genre_col = ("track", "genre_top")
        if inst_col in tracks.columns:
            is_inst = tracks[inst_col].astype(str).str.lower().isin(["true", "1", "1.0"])
        elif genre_col in tracks.columns:
            # NOTE(review): the "Instrumental" genre is a proxy for "no vocals";
            # confirm it is good enough for this split.
            is_inst = tracks[genre_col] == "Instrumental"
        else:
            raise KeyError(f"metadata has neither {inst_col} nor {genre_col}")
    except Exception as e:
        print(f"  FMA metadata load failed ({e}), treating all as vocal")
        return [], all_files

    # Vectorized split instead of one .loc lookup per track.
    mask = is_inst.to_numpy(dtype=bool)
    instrumental_ids = set(tracks.index[mask])
    vocal_ids = set(tracks.index[~mask])

    def tid_from_path(p):
        """Track id from a name like '000002.mp3'; -1 if the stem is not an int."""
        try:
            return int(os.path.splitext(os.path.basename(p))[0])
        except ValueError:
            return -1

    inst_files = [f for f in all_files if tid_from_path(f) in instrumental_ids]
    vocal_files = [f for f in all_files if tid_from_path(f) in vocal_ids]
    print(f"  FMA instrumental: {len(inst_files)} files")
    print(f"  FMA vocal: {len(vocal_files)} files")
    return inst_files, vocal_files
def build_datasets(simulate_codec=False):
    """Build the combined training set, sizing each source from TARGET_HOURS.

    Each category's target is read as hours of unique source audio. The
    number of files taken is ``hours * 3600 / typical_file_seconds``, capped
    by what is available and at least 1. With the default config this gives:
      * 240 FMA instrumental and 120 FMA vocal tracks (30 s each),
      * 720 LJSpeech clips (~5 s each),
      * 300 FSD50K clips when enabled (an assumed ~6 s each).
    How many random SEGMENT_SEC crops each file contributes per epoch is
    set separately by AudioDataset's ``crops`` argument.

    Args:
        simulate_codec: forwarded to every AudioDataset.

    Returns:
        A ConcatDataset over every source that yielded at least one file.

    Raises:
        RuntimeError: if no source yielded any files.
    """
    print("\nBuilding datasets:")
    datasets = []

    # Typical per-file durations (seconds) used to turn hours into file counts.
    fma_clip_sec = 30
    ljspeech_clip_sec = 5
    fsd50k_clip_sec = 6

    def files_for_hours(hours, file_dur_sec):
        """Number of ~file_dur_sec-second files needed for ``hours`` of audio (>= 1)."""
        return max(1, int(round(hours * 3600 / file_dur_sec)))

    def add_source(files, label, hours, file_dur_sec, **dataset_kwargs):
        """Shuffle ``files`` in place and append an AudioDataset over enough of them."""
        if not files:
            print(f"  {label}: no files found — skipping")
            return
        random.shuffle(files)
        n = min(len(files), files_for_hours(hours, file_dur_sec))
        datasets.append(AudioDataset(files[:n], label,
                                     simulate_codec=simulate_codec,
                                     **dataset_kwargs))

    # ── FMA ──
    if FMA_DIR and os.path.exists(FMA_DIR):
        inst_files, vocal_files = get_fma_files(FMA_DIR, FMA_META)
        add_source(inst_files, "FMA instrumental",
                   TARGET_HOURS["music_instrumental"], fma_clip_sec)
        add_source(vocal_files, "FMA vocals",
                   TARGET_HOURS["music_vocals"], fma_clip_sec)
    else:
        print("  FMA not found — skipping")

    # ── LJSpeech ──
    # NOTE(review): LJSpeech is distributed at 22.05 kHz, so after resampling
    # to SR_HIGH its targets carry nothing above ~11 kHz. Confirm that is
    # acceptable as a super-resolution target.
    if LJSPEECH_DIR and os.path.exists(LJSPEECH_DIR):
        speech_files = glob.glob(os.path.join(LJSPEECH_DIR, "*.wav"))
        add_source(speech_files, "LJSpeech", TARGET_HOURS["speech"],
                   ljspeech_clip_sec, crops=12)  # shorter files -> fewer crops
    else:
        print("  LJSpeech not found — skipping")

    # ── FSD50K ──
    if FSD50K_DIR and os.path.exists(FSD50K_DIR):
        sfx_files = glob.glob(os.path.join(FSD50K_DIR, "**/*.wav"), recursive=True)
        add_source(sfx_files, "FSD50K", TARGET_HOURS["sfx"],
                   fsd50k_clip_sec, crops=8)
    else:
        print("  FSD50K not found — skipping")

    # An explicit raise, unlike assert, survives `python -O`.
    if not datasets:
        raise RuntimeError("No datasets found! Check your paths.")
    combined = ConcatDataset(datasets)
    print(f"\nTotal: {len(combined):,} segments/epoch\n")
    return combined
# ─────────────────────────────────────────────────────────────
# TRAINING
# ─────────────────────────────────────────────────────────────
def _gan_step(G, D, opt_G, opt_D, scaler, low, high, use_amp):
    """Run one discriminator update, then one generator update, on a batch.

    Returns:
        (loss_G, loss_D, loss_spec) as Python floats.
    """
    # ── Discriminator step ──
    with torch.cuda.amp.autocast(enabled=use_amp):
        fake = G(low)
        min_t = min(fake.shape[-1], high.shape[-1])
        fake = fake[..., :min_t]
        real = high[..., :min_t]
        loss_D = discriminator_loss(D, real, fake)  # detaches fake internally

    opt_D.zero_grad()
    scaler.scale(loss_D).backward()
    # Gradients are still multiplied by the loss scale here; unscale first so
    # clip_grad_norm_ compares the true gradient norm against 1.0.
    scaler.unscale_(opt_D)
    torch.nn.utils.clip_grad_norm_(D.parameters(), 1.0)
    scaler.step(opt_D)

    # ── Generator step ──
    # G has not been updated since `fake` was computed, so reuse it instead
    # of running a second, identical forward pass.
    with torch.cuda.amp.autocast(enabled=use_amp):
        loss_spec = multi_res_stft_loss(fake, real)
        loss_l1 = F.l1_loss(fake, real)
        loss_adv = generator_loss(D, fake)
        # Reconstruction terms (spectral + 10x L1) plus a small adversarial term.
        loss_G = loss_spec * 1.0 + loss_l1 * 10.0 + loss_adv * 0.1

    opt_G.zero_grad()
    scaler.scale(loss_G).backward()
    scaler.unscale_(opt_G)
    torch.nn.utils.clip_grad_norm_(G.parameters(), 1.0)
    scaler.step(opt_G)
    scaler.update()

    return loss_G.item(), loss_D.item(), loss_spec.item()


def _save_final(G):
    """Save float32 weights to SAVE_PATH and a float16 copy next to it."""
    torch.save(G.state_dict(), SAVE_PATH)
    size_kb = os.path.getsize(SAVE_PATH) / 1024
    print(f"\nDone! {SAVE_PATH} ({size_kb:.1f} KB float32)")

    # float16 version (~half size)
    G_half = TinySR2()
    G_half.load_state_dict(torch.load(SAVE_PATH, map_location="cpu"))
    G_half = G_half.half()
    root, ext = os.path.splitext(SAVE_PATH)
    fp16_path = f"{root}_fp16{ext}"
    torch.save(G_half.state_dict(), fp16_path)
    print(f"fp16: {fp16_path} ({os.path.getsize(fp16_path) / 1024:.1f} KB)")


def train():
    """Train the TinySR2 generator adversarially against the STFT discriminator.

    Outputs:
      * SAVE_DIR/tinysr2_best.pt: generator with the lowest mean spectral loss.
      * SAVE_DIR/tinysr2_ep{N}.pt: checkpoint every 50 epochs.
      * SAVE_PATH: final float32 weights, plus a ``_fp16`` float16 copy.
      * LOG_PATH: per-epoch metrics as CSV.

    Raises:
        RuntimeError: if the dataset cannot fill a single batch.
    """
    os.makedirs(SAVE_DIR, exist_ok=True)
    use_amp = DEVICE == "cuda"
    print(f"Device: {DEVICE}")
    if use_amp:
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    dataset = build_datasets(simulate_codec=SIMULATE_HEAAC)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True,
                        num_workers=NUM_WORKERS, pin_memory=use_amp,
                        drop_last=True)
    if len(loader) == 0:
        # With drop_last=True every epoch would be empty and the means NaN.
        raise RuntimeError(
            f"Only {len(dataset)} segments, fewer than BATCH_SIZE={BATCH_SIZE}")

    G = TinySR2().to(DEVICE)
    D = MultiResSTFTDiscriminator().to(DEVICE)
    opt_G = torch.optim.AdamW(G.parameters(), lr=LR, betas=(0.8, 0.99))
    opt_D = torch.optim.AdamW(D.parameters(), lr=LR * 0.5, betas=(0.8, 0.99))
    sched_G = torch.optim.lr_scheduler.CosineAnnealingLR(opt_G, T_max=EPOCHS)
    sched_D = torch.optim.lr_scheduler.CosineAnnealingLR(opt_D, T_max=EPOCHS)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    g_params = sum(p.numel() for p in G.parameters())
    d_params = sum(p.numel() for p in D.parameters())
    print(f"TinySR2 generator: {g_params:,} params")
    print(f"Discriminator:     {d_params:,} params (training only)")
    print(f"Batch size: {BATCH_SIZE}")
    print(f"Starting {EPOCHS} epochs...\n")

    best_spec = float("inf")
    with open(LOG_PATH, "w") as log:
        log.write("epoch,g_loss,d_loss,spec_loss,lr,time_s\n")

        for epoch in range(1, EPOCHS + 1):
            G.train()
            D.train()
            g_losses, d_losses, spec_losses = [], [], []
            t0 = time.time()

            for low, high in loader:
                low = low.to(DEVICE)
                high = high.to(DEVICE)
                g_loss, d_loss, spec_loss = _gan_step(
                    G, D, opt_G, opt_D, scaler, low, high, use_amp)
                g_losses.append(g_loss)
                d_losses.append(d_loss)
                spec_losses.append(spec_loss)

            sched_G.step()
            sched_D.step()

            elapsed = time.time() - t0
            mean_g = np.mean(g_losses)
            mean_d = np.mean(d_losses)
            mean_spec = np.mean(spec_losses)
            cur_lr = sched_G.get_last_lr()[0]

            log.write(f"{epoch},{mean_g:.5f},{mean_d:.5f},{mean_spec:.5f},{cur_lr:.6f},{elapsed:.1f}\n")
            log.flush()

            if epoch % 10 == 0 or epoch == 1:
                print(f"Epoch {epoch:3d}/{EPOCHS} | "
                      f"G: {mean_g:.4f} | D: {mean_d:.4f} | "
                      f"Spec: {mean_spec:.4f} | "
                      f"LR: {cur_lr:.2e} | {elapsed:.0f}s")

            # Save best (by mean training spectral loss)
            if mean_spec < best_spec:
                best_spec = mean_spec
                torch.save(G.state_dict(), os.path.join(SAVE_DIR, "tinysr2_best.pt"))

            # Periodic checkpoints
            if epoch % 50 == 0:
                ckpt = os.path.join(SAVE_DIR, f"tinysr2_ep{epoch}.pt")
                torch.save(G.state_dict(), ckpt)
                print(f"  → Checkpoint: {ckpt}")

    _save_final(G)


if __name__ == "__main__":
    train()
Benchmark Comparison
| Model | Trained by Me? | Size |
|---|---|---|
| TinySR 2 (MihaiPopa-1/TinySR) | ✅ | 34.3KB |
| NovaSR (YatharthS/NovaSR) | ❌ | 52KB |
| TinySR 1 (MihaiPopa-1/TinySR) | ✅ | 57.6KB |
| FlashSR (YatharthS/FlashSR) | ❌ | 2MB (PyTorch) / 500KB (ONNX) |
| AudioSR (haoheliu/audiosr_basic) | ❌ | 6.18GB |
Other things
I tested the model on Odetari - Keep Up (an HE-AACv2 encode at 8 kbps, 16 kHz, stereo, made with FDK-AAC), and it worked!
The output doesn't exactly match the ground truth (and it isn't perfect), but it's better than I expected!