TinySR 2 Readme!

TinySR 2 is the world's smallest super-resolution model (for audio) that I trained!

It's only 34.3 KB (for the GPT-5.3 Mini / 28-channel version) in size and it can easily run on a CPU or GPU!

There are other checkpoints as well: a version written by Claude 4.6 Sonnet (larger, at 216 KB, trained with a GAN) and a 12-channel version (14.3 KB, trained for ~380 steps).

Details

  • Input Sample Rate: 16 kHz
  • Output Sample Rate: 48 kHz
  • Size: Only 34.3 KB

The model works on mono (and stereo, by upscaling both channels) audio!

Variants

| File Name | Description | Architecture | Audio Used | Steps/Epochs Trained | Size |
|---|---|---|---|---|---|
| TinySR 2 - Run 2 - 2500 Steps.pt | GPT-5.3 Mini version, 28 channels | 28-channel Conv1D, SiLU (was ReLU on TinySR 1) | 77 minutes of AI audio* | Run 2, 2500 steps | 34.3KB |
| TinySR 2 - Run 2 - 2000 Steps.pt | GPT-5.3 Mini version, 28 channels | 28-channel Conv1D, SiLU | 77 minutes of AI audio* | Run 2, 2000 steps | 32.3KB |
| TinySR 2 - Run 2 - 1000 Steps.pt | GPT-5.3 Mini version, 28 channels | 28-channel Conv1D, SiLU | 77 minutes of AI audio* | Run 2, 1000 steps | 32.3KB |
| TinySR 2 - Run 1 - 1000 Steps.pt | GPT-5.3 Mini version, 28 channels | 28-channel Conv1D, SiLU | 77 minutes of AI audio* | Run 1, 1000 steps | 32.3KB |
| TinySR 2 - 12 Channels - 380 Steps.pt | GPT-5.3 Mini version, 12 channels | 12-channel Conv1D, SiLU | 77 minutes of AI audio* | 380 steps | 14.3KB |
| TinySR 2 - Claude 4.6 Sonnet.pt | Claude 4.6 Sonnet version | 64-channel, 8-block Conv1D, SiLU, GAN, pixel-shuffle upsampling | FMA Small + LJSpeech | 12 epochs | 216KB |

*it was a mix of speech (ListenHub, Google's NotebookLM Audio Overviews) + music (Suno 3.5, Suno 4.5-all, Mureka 8, Mubert)

Code

For the GPT-5.3 Mini version:

# =========================================================
# TinySR 2.5
# ~6K parameter audio super-resolution model (5,965 parameters at 28 channels)
# 16 kHz -> 48 kHz
# Optimized for:
# - Tiny size
# - Better receptive field
# - Fast inference
# - Better perceptual quality
# =========================================================

import os
import random
import torch
import soundfile as sf
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader

# =========================================================
# 1️⃣ Residual Dilated Block
# =========================================================

class TinyBlock(nn.Module):
    """Residual unit: dilated depthwise conv -> pointwise conv -> SiLU.

    The block adds its activation output back onto its input, so the
    output has exactly the same shape as the input.

    Args:
        channels: Number of input (and output) channels.
        dilation: Dilation of the 5-tap depthwise convolution.
    """

    def __init__(self, channels, dilation):
        super().__init__()

        # One 5-tap filter per channel (groups == channels). A padding of
        # 2 * dilation on each side keeps the time axis length unchanged.
        self.dw = nn.Conv1d(
            channels, channels,
            kernel_size=5, padding=2 * dilation,
            dilation=dilation, groups=channels,
        )
        # 1x1 convolution that mixes information across channels.
        self.pw = nn.Conv1d(channels, channels, kernel_size=1)
        self.act = nn.SiLU()

    def forward(self, x):
        """Return ``x`` plus the activated depthwise/pointwise branch."""
        return x + self.act(self.pw(self.dw(x)))


# =========================================================
# 2️⃣ TinySR 2.5 Model
# =========================================================

class TinySR25(nn.Module):
    """3x audio upsampler that refines a linear-interpolation baseline.

    The input is first stretched 3x along time with linear interpolation
    (the surrounding script uses this for 16 kHz -> 48 kHz). A stack of
    dilated residual blocks then predicts a correction that is added back
    onto that baseline.

    Args:
        channels: Width of the hidden feature maps.
    """

    def __init__(self, channels=28):
        super().__init__()

        # Lift the single audio channel into the feature space.
        self.input = nn.Conv1d(1, channels, kernel_size=1)

        # Doubling dilations widen the receptive field cheaply.
        self.blocks = nn.ModuleList(
            [TinyBlock(channels, dilation) for dilation in (1, 2, 4, 8, 16, 32)]
        )

        # Collapse the features back to one channel (the residual signal).
        self.output = nn.Conv1d(channels, 1, kernel_size=1)

    def forward(self, x):
        """Return the upsampled signal: interpolated input plus residual."""
        base = F.interpolate(
            x, scale_factor=3, mode="linear", align_corners=False
        )

        features = self.input(base)
        for block in self.blocks:
            features = block(features)

        return base + self.output(features)


# =========================================================
# 3️⃣ Dataset Loader
# =========================================================

class AudioPairDataset(Dataset):
    """Paired low-/high-resolution audio segments read from two folders.

    Every file name listed in ``lr_dir`` is expected to exist in ``hr_dir``
    as well. Each item is a ``(lr, hr)`` pair of float32 tensors with a
    leading channel axis of size 1 and a fixed number of samples
    (``seg_len_lr`` and ``seg_len_hr``), so items can always be stacked by
    the default DataLoader collate function.

    Args:
        lr_dir: Directory holding the low-resolution files.
        hr_dir: Directory holding the matching high-resolution files.
        segment_seconds: Segment duration, counted at 16 kHz for the
            low-resolution side; the high-resolution side is 3x longer.
    """

    def __init__(
        self,
        lr_dir,
        hr_dir,
        segment_seconds=1.0
    ):

        self.lr_dir = lr_dir
        self.hr_dir = hr_dir

        self.files = sorted(os.listdir(lr_dir))

        # 16 kHz input
        self.seg_len_lr = int(
            16000 * segment_seconds
        )

        # 48 kHz target (3x the low-resolution segment)
        self.seg_len_hr = self.seg_len_lr * 3

    def __len__(self):
        return len(self.files)

    @staticmethod
    def _fit_length(signal, length):
        """Truncate or right-pad a 1-D tensor with zeros to ``length``."""
        signal = signal[:length]
        missing = length - signal.shape[-1]
        if missing > 0:
            signal = F.pad(signal, (0, missing))
        return signal

    def __getitem__(self, idx):
        """Load pair ``idx`` and return one fixed-length ``(lr, hr)`` segment."""

        name = self.files[idx]

        lr, _ = sf.read(
            os.path.join(self.lr_dir, name),
            dtype="float32"
        )

        hr, _ = sf.read(
            os.path.join(self.hr_dir, name),
            dtype="float32"
        )

        # Convert stereo -> mono if needed
        if lr.ndim > 1:
            lr = lr.mean(axis=1)

        if hr.ndim > 1:
            hr = hr.mean(axis=1)

        # Random crop; the HR window starts at 3x the LR offset so both
        # windows cover the same stretch of audio.
        if len(lr) > self.seg_len_lr:

            start = random.randint(
                0,
                len(lr) - self.seg_len_lr
            )

            lr = lr[start:start + self.seg_len_lr]
            hr = hr[start * 3:start * 3 + self.seg_len_hr]

        # Clips shorter than one segment (or HR files shorter than 3x the
        # LR file) used to come back with a different length, which breaks
        # batching. Force both sides to the fixed segment lengths.
        lr = self._fit_length(torch.from_numpy(lr), self.seg_len_lr)
        hr = self._fit_length(torch.from_numpy(hr), self.seg_len_hr)

        return (
            lr.unsqueeze(0),
            hr.unsqueeze(0)
        )


# =========================================================
# 4️⃣ Multi-Resolution STFT Loss
# =========================================================

def stft_loss(x, y):
    """Mean absolute STFT-magnitude difference, averaged over 3 FFT sizes.

    For each FFT size in (256, 512, 1024) both signals are transformed
    with a Hann window and a hop of a quarter window; the L1 distance
    between the magnitude spectrograms is averaged across the sizes.

    Args:
        x: Tensor whose dimension 1 has size 1 (squeezed before the STFT).
        y: Tensor with the same layout as ``x``.

    Returns:
        Scalar tensor holding the averaged loss.
    """

    def magnitude(signal, n_fft, window):
        spectrum = torch.stft(
            signal.squeeze(1),
            n_fft=n_fft,
            hop_length=n_fft // 4,
            window=window,
            return_complex=True,
        )
        return spectrum.abs()

    fft_sizes = (256, 512, 1024)
    total = 0.0

    for n_fft in fft_sizes:
        # The window is built on the input's device for every size.
        hann = torch.hann_window(n_fft, device=x.device)
        gap = magnitude(x, n_fft, hann) - magnitude(y, n_fft, hann)
        total = total + gap.abs().mean()

    return total / len(fft_sizes)


# =========================================================
# 5️⃣ Setup
# =========================================================

# Prefer the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print("Using device:", device)

model = TinySR25(channels=28)
model = model.to(device)

# Report the total number of parameter elements in the model.
params = sum(map(torch.numel, model.parameters()))

print(f"Parameters: {params}")

# AdamW over every model parameter.
opt = torch.optim.AdamW(model.parameters(), lr=2e-4)

# ---------------------------------------------------------
# Data: paired files under data/lr and data/hr, cut into
# 1-second segments and served in shuffled batches of 16.
# ---------------------------------------------------------

dataset = AudioPairDataset(
    lr_dir="data/lr", hr_dir="data/hr", segment_seconds=1.0
)

loader = DataLoader(
    dataset, batch_size=16, shuffle=True, num_workers=2, pin_memory=True
)

# =========================================================
# 7️⃣ Training Loop
# =========================================================

step = 0
max_steps = 100000

model.train()

# Keep cycling through the loader until max_steps updates have been made.
training = True

while training:

    for low_res, high_res in loader:

        low_res, high_res = low_res.to(device), high_res.to(device)

        # --- forward pass ---
        opt.zero_grad()
        out = model(low_res)

        # --- losses: waveform L1 plus a down-weighted spectral term ---
        l1 = F.l1_loss(out, high_res)
        spec = stft_loss(out, high_res)
        loss = l1 + 0.25 * spec

        # --- backward pass with gradient-norm clipping at 1.0 ---
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()

        step += 1

        # --- progress line every 10 steps ---
        if step % 10 == 0:
            print(
                f"Step {step} | L1 {l1.item():.4f} | "
                f"STFT {spec.item():.4f} | Total {loss.item():.4f}"
            )

        # --- every 1000 steps: checkpoint, then render a validation clip ---
        if step % 1000 == 0:

            ckpt_name = f"tinysr25_{step}.pt"
            torch.save(model.state_dict(), ckpt_name)
            print(f"Saved {ckpt_name}")

            model.eval()

            unseen, _ = sf.read("unseen.wav", dtype="float32")

            # Average the channels of a multi-channel file down to mono.
            if unseen.ndim > 1:
                unseen = unseen.mean(axis=1)

            # Add batch and channel axes before moving to the device.
            val_input = torch.from_numpy(unseen)[None, None, :].to(device)

            with torch.no_grad():
                val_output = model(val_input)

            out_name = f"val_{step}.wav"
            sf.write(out_name, val_output.squeeze().cpu().numpy(), 48000)
            print(f"Saved {out_name}")

            model.train()

        # --- stop mid-epoch once the step budget is used up ---
        if step >= max_steps:
            break

    # Re-evaluated after every pass over the loader, whether or not the
    # inner loop ended early.
    training = step < max_steps

print("Training complete!")

The script above is for the 28-channel version. The 12-channel version used mostly the same script, with only the number of channels changed.

For Claude 4.6 Sonnet's version:

"""
TinySR2 GAN Training Script
~52K parameters, optimized for Google Colab T4 GPU.

═══════════════════════════════════════════════
COLAB SETUP (run these cells first):
═══════════════════════════════════════════════

# 1. Install dependencies
!pip install torch torchaudio soundfile numpy auraloss

# 2. Download datasets
# FMA Small (~7.2GB):
!wget -c https://os.unil.cloud.switch.ch/fma/fma_small.zip
!unzip -q fma_small.zip -d ./audio

# FMA Metadata (for instrumental filter):
!wget -c https://os.unil.cloud.switch.ch/fma/fma_metadata.zip
!unzip -q fma_metadata.zip -d ./fma_metadata

# LJSpeech (~2.6GB):
!wget -c https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
!tar -xjf LJSpeech-1.1.tar.bz2 -C ./audio

# FSD50K subset — download just Eval set (~10GB → ~4GB):
# https://zenodo.org/record/4060432
# Or skip and just use FMA + LJSpeech

# 3. Run this script:
!python train_tinysr2.py
"""

import os
import glob
import random
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import soundfile as sf
from torch.utils.data import Dataset, DataLoader, ConcatDataset

# ─────────────────────────────────────────────────────────────
# CONFIG
# ─────────────────────────────────────────────────────────────
# Dataset directories — set to None to skip
FMA_DIR = "./audio/fma_small"  # FMA Small MP3s
FMA_META = "./fma_metadata/tracks.csv"  # FMA metadata for filtering
LJSPEECH_DIR = "./audio/LJSpeech-1.1/wavs"  # LJSpeech WAVs
FSD50K_DIR = None  # FSD50K eval set, or None to skip

# Output locations
SAVE_DIR = "./checkpoints"
SAVE_PATH = "./tinysr2.pt"
LOG_PATH = "./training_log.txt"

# Optimisation schedule
EPOCHS = 300
BATCH_SIZE = 64  # T4 16GB handles this easily at 52K params
LR = 2e-4

# Audio geometry: segment duration and the two sample rates
SEGMENT_SEC = 1.0
SR_HIGH = 48000
SR_LOW = 16000

# Generator size
CHANNELS = 64  # TinySR2: 2× wider than v1
BLOCKS = 8  # same depth as v1

# Runtime
NUM_WORKERS = 2
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Target hours per category (will be balanced via sampling weights)
TARGET_HOURS = dict(
    music_instrumental=2.0,
    music_vocals=1.0,
    speech=1.0,
    sfx=0.5,  # ignored if FSD50K_DIR is None
)

# Flip True once you have ffmpeg + libfdk_aac installed
SIMULATE_HEAAC = False


# ─────────────────────────────────────────────────────────────
# GENERATOR — TinySR2
# ─────────────────────────────────────────────────────────────
class TinySR2(nn.Module):
    """
    TinySR2 generator (about 52K parameters at the default size).

    - Features are computed at the input rate by a stack of residual
      depthwise + pointwise blocks with SiLU activations.
    - A 1x1 convolution triples the channel count and the result is
      rearranged into 3x more time steps (1-D pixel shuffle).
    - A linearly interpolated copy of the input is added to the output
      as a global skip connection.
    """

    def __init__(self, channels=CHANNELS, blocks=BLOCKS):
        super().__init__()

        def make_block():
            return nn.Sequential(
                nn.Conv1d(channels, channels, 9, padding=4, groups=channels),  # depthwise
                nn.Conv1d(channels, channels, 1),                               # pointwise
                nn.SiLU(),
            )

        self.input_conv = nn.Conv1d(1, channels, 9, padding=4)
        self.blocks = nn.ModuleList([make_block() for _ in range(blocks)])

        # Produces 3 values per channel and time step for the pixel shuffle.
        self.upsample_conv = nn.Conv1d(channels, channels * 3, 1)
        self.output_conv = nn.Conv1d(channels, 1, 9, padding=4)

    def forward(self, x):
        """Return the 3x-upsampled signal for input ``x`` of shape (B, 1, T)."""
        # Global skip path: plain linear interpolation of the input.
        skip = F.interpolate(x, scale_factor=3, mode="linear", align_corners=False)

        feats = self.input_conv(x)
        for block in self.blocks:
            feats = feats + block(feats)  # per-block residual connection

        batch, chans, frames = feats.shape
        expanded = self.upsample_conv(feats)               # (B, C*3, T)
        expanded = expanded.view(batch, chans, 3, frames)  # split out the 3 sub-steps
        expanded = expanded.transpose(2, 3)                # (B, C, T, 3)
        expanded = expanded.reshape(batch, chans, frames * 3)  # interleave -> (B, C, T*3)

        return self.output_conv(expanded) + skip


# ─────────────────────────────────────────────────────────────
# DISCRIMINATOR (training only, not deployed)
# ─────────────────────────────────────────────────────────────
class MultiResSTFTDiscriminator(nn.Module):
    """
    STFT-magnitude discriminator working at three FFT resolutions.

    One small 2-D CNN per resolution scores the magnitude spectrogram;
    the three scores are averaged. The three stacks hold 224,163
    parameters in total and are only needed while training.
    """

    def __init__(self):
        super().__init__()
        self.resolutions = [256, 512, 1024]
        self.stacks = nn.ModuleList(
            [self._make_stack() for _ in self.resolutions]
        )

    @staticmethod
    def _make_stack():
        """Build the CNN that maps one spectrogram to a single score."""
        layers = [
            nn.Conv2d(1, 32, (3, 3), padding=1),
            nn.SiLU(),
            nn.Conv2d(32, 64, (3, 3), stride=(2, 2), padding=1),
            nn.SiLU(),
            nn.Conv2d(64, 64, (3, 3), stride=(2, 2), padding=1),
            nn.SiLU(),
            nn.Conv2d(64, 32, (3, 3), stride=(2, 2), padding=1),
            nn.SiLU(),
            # Pool to a fixed 4x4 grid so the linear layer sees 32 * 16
            # features regardless of the spectrogram size.
            nn.AdaptiveAvgPool2d((4, 4)),
            nn.Flatten(),
            nn.Linear(32 * 16, 1),
        ]
        return nn.Sequential(*layers)

    def stft_mag(self, x, n_fft):
        """Return the STFT magnitude (+1e-8) of ``x`` with a channel axis added."""
        waveform = x.squeeze(1)
        hann = torch.hann_window(n_fft, device=waveform.device)
        spectrum = torch.stft(
            waveform,
            n_fft=n_fft,
            hop_length=n_fft // 4,
            win_length=n_fft,
            window=hann,
            return_complex=True,
        )
        magnitude = spectrum.abs() + 1e-8
        return magnitude.unsqueeze(1)

    def forward(self, x):
        """Average the per-resolution scores into one score per batch item."""
        scores = []
        for n_fft, stack in zip(self.resolutions, self.stacks):
            scores.append(stack(self.stft_mag(x, n_fft)))
        return torch.stack(scores, dim=1).mean(dim=1)


# ─────────────────────────────────────────────────────────────
# LOSSES
# ─────────────────────────────────────────────────────────────
def multi_res_stft_loss(pred, target, resolutions=(256, 512, 1024, 2048)):
    """Multi-resolution spectral convergence + log magnitude loss.

    For every FFT size the magnitudes of both signals (offset by 1e-8 so
    the logarithm is defined) are compared with
      * spectral convergence: ||T - P||_F / ||T||_F, computed over the
        whole batch at once, and
      * the mean absolute difference of the log magnitudes.
    The two terms are summed and averaged over the resolutions.

    Args:
        pred: Predicted waveform; dimension 1 must have size 1.
        target: Reference waveform with the same layout as ``pred``.
        resolutions: Iterable of FFT sizes. The default is an immutable
            tuple rather than a shared mutable list.

    Returns:
        Scalar tensor with the averaged loss.

    Raises:
        ValueError: If ``resolutions`` is empty.
    """
    resolutions = tuple(resolutions)
    if not resolutions:
        raise ValueError("resolutions must contain at least one FFT size")

    def magnitude(signal, n_fft, window):
        spec = torch.stft(signal.squeeze(1), n_fft=n_fft, hop_length=n_fft // 4,
                          win_length=n_fft, window=window, return_complex=True)
        return spec.abs() + 1e-8

    total = 0.0
    for n_fft in resolutions:
        window = torch.hann_window(n_fft, device=pred.device)
        p_mag = magnitude(pred, n_fft, window)
        t_mag = magnitude(target, n_fft, window)
        sc = torch.norm(t_mag - p_mag, "fro") / (torch.norm(t_mag, "fro") + 1e-8)
        log = F.l1_loss(torch.log(p_mag), torch.log(t_mag))
        total += sc + log
    return total / len(resolutions)


def generator_loss(disc, fake):
    """Least-squares generator loss: push ``disc(fake)`` towards 1.

    The target is a column of ones with one row per batch item, created
    on the same device as ``fake``.
    """
    scores = disc(fake)
    real_label = torch.ones(fake.size(0), 1, device=fake.device)
    return F.mse_loss(scores, real_label)


def discriminator_loss(disc, real, fake):
    """Least-squares discriminator loss: real -> 1, fake -> 0.

    ``fake`` is detached before scoring, so no gradient flows back into
    whatever produced it. The result is the mean of the two MSE terms.
    """
    real_scores = disc(real)
    real_term = F.mse_loss(
        real_scores, torch.ones(real.size(0), 1, device=real.device)
    )

    fake_scores = disc(fake.detach())
    fake_term = F.mse_loss(
        fake_scores, torch.zeros(fake.size(0), 1, device=fake.device)
    )

    return (real_term + fake_term) / 2


# ─────────────────────────────────────────────────────────────
# HE-AACv2 SIMULATION
# ─────────────────────────────────────────────────────────────
def simulate_heaac(audio_np):
    """
    Encode through HE-AACv2 at 8 kbps (16 kHz mono) then decode.

    The input is written as a WAV at SR_HIGH, encoded with ffmpeg's
    libfdk_aac encoder, decoded back to a WAV at SR_LOW, and returned as a
    float32 array of exactly ``int(SEGMENT_SEC * SR_LOW)`` samples
    (zero-padded at the end or truncated as needed).

    Requires ffmpeg with libfdk_aac. Install on Colab:

      !apt-get remove -y ffmpeg
      # Download static ffmpeg build with fdk-aac:
      !wget https://johnvansickle.com/ffmpeg/releases/ffmpeg-release-amd64-static.tar.xz
      !tar -xf ffmpeg-release-amd64-static.tar.xz
      !cp ffmpeg-*-static/ffmpeg /usr/local/bin/
      # NOTE(review): the original note says these builds include fdk-aac;
      # confirm that the build you download really ships libfdk_aac.

    Raises:
        subprocess.CalledProcessError: If either ffmpeg invocation exits
            with a non-zero status (previously this surfaced later as a
            read error on the missing output file).
        FileNotFoundError: If the ffmpeg executable is not on PATH.
    """
    import subprocess, tempfile
    seg_lo = int(SEGMENT_SEC * SR_LOW)

    # All intermediate files live in one private directory that is removed
    # on exit, even when ffmpeg or the decode step fails.
    with tempfile.TemporaryDirectory() as workdir:
        in_path = os.path.join(workdir, "input.wav")
        out_aac = os.path.join(workdir, "encoded.aac")
        out_wav = os.path.join(workdir, "decoded.wav")

        sf.write(in_path, audio_np, SR_HIGH)

        subprocess.run([
            "ffmpeg", "-y", "-i", in_path,
            "-c:a", "libfdk_aac", "-profile:a", "aac_he_v2",
            "-b:a", "8k", "-ar", "16000", "-ac", "1",
            out_aac
        ], capture_output=True, check=True)

        subprocess.run([
            "ffmpeg", "-y", "-i", out_aac,
            "-ar", str(SR_LOW), out_wav
        ], capture_output=True, check=True)

        degraded, _ = sf.read(out_wav, dtype="float32")

    # The decode command does not force a channel count, so fold any
    # multi-channel result down to mono before fixing the length.
    if degraded.ndim > 1:
        degraded = degraded.mean(axis=1)

    if len(degraded) < seg_lo:
        degraded = np.pad(degraded, (0, seg_lo - len(degraded)))
    return degraded[:seg_lo]


# ─────────────────────────────────────────────────────────────
# DATASET
# ─────────────────────────────────────────────────────────────
class AudioDataset(Dataset):
    """
    Generic audio dataset. Loads files, crops to SEGMENT_SEC,
    returns (low_res, high_res) pairs.

    Each file contributes ``crops`` items per epoch; every access takes a
    fresh random crop. Both tensors of a pair are divided by the peak
    absolute value of the high-resolution crop.

    Args:
        files: List of audio file paths.
        label: Name used in log and warning messages.
        crops: Number of items each file contributes per epoch.
        simulate_codec: If True, the low-resolution side comes from
            ``simulate_heaac`` instead of plain resampling.
    """
    def __init__(self, files, label="", crops=30, simulate_codec=False):
        self.files    = files
        self.label    = label
        self.crops    = crops
        self.seg_hi   = int(SEGMENT_SEC * SR_HIGH)
        self.seg_lo   = int(SEGMENT_SEC * SR_LOW)
        self.simulate = simulate_codec
        print(f"  [{label}] {len(files)} files → {len(self):,} segments/epoch")

    def __len__(self):
        return len(self.files) * self.crops

    def __getitem__(self, idx):
        """Return one (low, high) pair; all-zero tensors if loading fails."""
        # Indices wrap around the file list, so a file is visited `crops` times.
        path = self.files[idx % len(self.files)]
        try:
            audio, sr = torchaudio.load(path)
            audio = audio.mean(0, keepdim=True)                # mono
            if sr != SR_HIGH:
                audio = torchaudio.functional.resample(audio, sr, SR_HIGH)

            # Random crop (files shorter than one segment are zero-padded first)
            if audio.shape[1] < self.seg_hi:
                audio = F.pad(audio, (0, self.seg_hi - audio.shape[1]))
            start = random.randint(0, max(0, audio.shape[1] - self.seg_hi))
            high  = audio[:, start:start + self.seg_hi]        # (1, T_hi)

            # Low-res pair
            if self.simulate:
                low_np = simulate_heaac(high.squeeze().numpy())
                low    = torch.from_numpy(low_np).unsqueeze(0)
            else:
                low = torchaudio.functional.resample(high, SR_HIGH, SR_LOW)

            # Normalize by the high-res peak (clamped to avoid dividing by zero)
            peak = high.abs().max().clamp(min=1e-8)
            return (low / peak).float(), (high / peak).float()

        except Exception as exc:
            # Best effort: one unreadable file must not stop training. The
            # failure is reported instead of being swallowed, otherwise a
            # systematic problem (for example no decoder for the file
            # format) would silently turn the whole dataset into silence.
            import warnings
            warnings.warn(
                f"[{self.label}] could not load {path!r} "
                f"({type(exc).__name__}: {exc}); using a silent segment instead"
            )
            return (torch.zeros(1, self.seg_lo),
                    torch.zeros(1, self.seg_hi))


def get_fma_files(fma_dir, meta_path):
    """Split FMA files into instrumental and vocal lists using metadata.

    MP3 files are found recursively under ``fma_dir``; the numeric file
    stem is taken as the track id. A track counts as instrumental when the
    metadata column ``('track', 'instrumental')`` holds a true-like value
    (1/True or the string "true" in any case). Every other track id in the
    metadata counts as vocal. Files whose id is not in the metadata, or
    whose name is not a number, appear in neither list.

    Args:
        fma_dir: Root directory of the FMA audio files.
        meta_path: Path of the metadata CSV (two header rows, track id in
            the first column).

    Returns:
        ``(instrumental_files, vocal_files)``. If the metadata cannot be
        loaded or processed, ``([], all_files)``.
    """
    all_files = glob.glob(os.path.join(fma_dir, "**/*.mp3"), recursive=True)

    def tid_from_path(p):
        # Non-numeric file names map to None, which never matches a track id.
        try:
            return int(os.path.splitext(os.path.basename(p))[0])
        except ValueError:
            return None

    def is_true_like(value):
        # `== 1` also matches True and numpy booleans.
        return value == 1 or str(value).lower() == "true"

    # Try to load metadata for instrumental split
    try:
        import pandas as pd
        tracks = pd.read_csv(meta_path, index_col=0, header=[0, 1], low_memory=False)

        # Look the column up once instead of once per row. When it is
        # absent every track is treated as vocal, which is what the
        # per-row lookup failure used to produce.
        # NOTE(review): confirm that tracks.csv really has this column.
        column = ("track", "instrumental")
        instrumental_ids = set()
        if column in tracks.columns:
            instrumental_ids = {
                tid for tid, flag in tracks[column].items() if is_true_like(flag)
            }
        vocal_ids = set(tracks.index) - instrumental_ids

        inst_files  = [f for f in all_files if tid_from_path(f) in instrumental_ids]
        vocal_files = [f for f in all_files if tid_from_path(f) in vocal_ids]

        print(f"  FMA instrumental: {len(inst_files)} files")
        print(f"  FMA vocal:        {len(vocal_files)} files")
        return inst_files, vocal_files

    except Exception as e:
        # Deliberate best effort: without usable metadata (or without
        # pandas) training still runs, with everything counted as vocal.
        print(f"  FMA metadata load failed ({e}), treating all as vocal")
        return [], all_files


def build_datasets(simulate_codec=False):
    """Build the combined training dataset from the configured sources.

    FMA is split into instrumental and vocal subsets whose file counts are
    derived from TARGET_HOURS. LJSpeech and FSD50K use fixed file caps
    (720 and 300); TARGET_HOURS["speech"] and TARGET_HOURS["sfx"] are not
    consulted by this function. Sources whose directory is unset or
    missing are skipped.

    Args:
        simulate_codec: Passed through to every AudioDataset.

    Returns:
        A ConcatDataset over all sources that were found.

    Raises:
        RuntimeError: If no source directory was found.
    """
    print("\nBuilding datasets:")
    datasets = []

    # ── FMA ──
    if FMA_DIR and os.path.exists(FMA_DIR):
        inst_files, vocal_files = get_fma_files(FMA_DIR, FMA_META)

        # How many files are needed for the target hours? Each file yields
        # `crops_per_file` crops of SEGMENT_SEC seconds per epoch. With the
        # defaults (30 crops of 1 s) that is 30 s of audio per file, so one
        # hour needs 120 files.
        def files_for_hours(hours, crops_per_file=30):
            segs_needed = int(hours * 3600 / SEGMENT_SEC)
            return max(1, segs_needed // crops_per_file)

        if inst_files:
            n = min(len(inst_files), files_for_hours(TARGET_HOURS["music_instrumental"]))
            random.shuffle(inst_files)
            datasets.append(AudioDataset(inst_files[:n], "FMA instrumental",
                                         simulate_codec=simulate_codec))

        if vocal_files:
            n = min(len(vocal_files), files_for_hours(TARGET_HOURS["music_vocals"]))
            random.shuffle(vocal_files)
            datasets.append(AudioDataset(vocal_files[:n], "FMA vocals",
                                         simulate_codec=simulate_codec))
    else:
        print("  FMA not found — skipping")

    # ── LJSpeech ──
    if LJSPEECH_DIR and os.path.exists(LJSPEECH_DIR):
        speech_files = glob.glob(os.path.join(LJSPEECH_DIR, "*.wav"))
        random.shuffle(speech_files)
        # Fixed cap of 720 files. The original note sized this as ~1 h of
        # source audio at ~5 s per file; with crops=12 it yields
        # 720 * 12 = 8,640 segments per epoch.
        n = min(len(speech_files), 720)
        datasets.append(AudioDataset(speech_files[:n], "LJSpeech",
                                     crops=12,           # shorter files → fewer crops
                                     simulate_codec=simulate_codec))
    else:
        print("  LJSpeech not found — skipping")

    # ── FSD50K ──
    if FSD50K_DIR and os.path.exists(FSD50K_DIR):
        sfx_files = glob.glob(os.path.join(FSD50K_DIR, "**/*.wav"), recursive=True)
        random.shuffle(sfx_files)
        # Fixed cap of 300 files, 8 crops each.
        n = min(len(sfx_files), 300)
        datasets.append(AudioDataset(sfx_files[:n], "FSD50K",
                                     crops=8,
                                     simulate_codec=simulate_codec))
    else:
        print("  FSD50K not found — skipping")

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not datasets:
        raise RuntimeError("No datasets found! Check your paths.")
    combined = ConcatDataset(datasets)
    print(f"\nTotal: {len(combined):,} segments/epoch\n")
    return combined


# ─────────────────────────────────────────────────────────────
# TRAINING
# ─────────────────────────────────────────────────────────────
def train():
    """Train the TinySR2 generator against the STFT discriminator.

    Every batch performs one discriminator update followed by one
    generator update under CUDA mixed precision (disabled on CPU). The
    generator loss is ``spec * 1.0 + L1 * 10.0 + adversarial * 0.1``.
    Per-epoch means are appended to LOG_PATH as CSV. Generator weights
    are saved

      * to SAVE_DIR/tinysr2_best.pt whenever the epoch-mean spectral loss
        improves,
      * to SAVE_DIR/tinysr2_ep{N}.pt every 50 epochs, and
      * to SAVE_PATH (float32) plus a ``_fp16`` copy after the last epoch.
    """
    os.makedirs(SAVE_DIR, exist_ok=True)
    use_cuda = DEVICE == "cuda"

    print(f"Device: {DEVICE}")
    if use_cuda:
        print(f"GPU:    {torch.cuda.get_device_name(0)}")
        print(f"VRAM:   {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    dataset = build_datasets(simulate_codec=SIMULATE_HEAAC)
    loader  = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True,
                         num_workers=NUM_WORKERS, pin_memory=use_cuda,
                         drop_last=True)

    G = TinySR2().to(DEVICE)
    D = MultiResSTFTDiscriminator().to(DEVICE)

    opt_G = torch.optim.AdamW(G.parameters(), lr=LR,       betas=(0.8, 0.99))
    opt_D = torch.optim.AdamW(D.parameters(), lr=LR * 0.5, betas=(0.8, 0.99))

    sched_G = torch.optim.lr_scheduler.CosineAnnealingLR(opt_G, T_max=EPOCHS)
    sched_D = torch.optim.lr_scheduler.CosineAnnealingLR(opt_D, T_max=EPOCHS)

    # One scaler serves both optimizers; update() is called once per batch,
    # after both optimizers have stepped.
    scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

    g_params = sum(p.numel() for p in G.parameters())
    d_params = sum(p.numel() for p in D.parameters())
    print(f"TinySR2 generator:     {g_params:,} params")
    print(f"Discriminator:         {d_params:,} params (training only)")
    print(f"Batch size:            {BATCH_SIZE}")
    print(f"Starting {EPOCHS} epochs...\n")

    best_spec = float("inf")

    # The context manager closes the log even if training is interrupted.
    with open(LOG_PATH, "w") as log:
        log.write("epoch,g_loss,d_loss,spec_loss,lr,time_s\n")

        for epoch in range(1, EPOCHS + 1):
            G.train(); D.train()
            g_losses, d_losses, spec_losses = [], [], []
            t0 = time.time()

            for low, high in loader:
                low  = low.to(DEVICE)
                high = high.to(DEVICE)

                # ── Discriminator step ──
                with torch.cuda.amp.autocast(enabled=use_cuda):
                    fake  = G(low)
                    # Trim both signals to a common length before comparing.
                    min_t = min(fake.shape[-1], high.shape[-1])
                    fake  = fake[..., :min_t]
                    real  = high[..., :min_t]
                    loss_D = discriminator_loss(D, real, fake)

                opt_D.zero_grad()
                scaler.scale(loss_D).backward()
                # Gradients are still multiplied by the loss scale here;
                # unscale them first so the 1.0 threshold applies to the
                # true gradient norm.
                scaler.unscale_(opt_D)
                torch.nn.utils.clip_grad_norm_(D.parameters(), 1.0)
                scaler.step(opt_D)

                # ── Generator step ──
                # `fake` from above is reused: discriminator_loss detaches
                # it, so its graph through G is intact, and G's weights
                # have not changed since it was computed.
                with torch.cuda.amp.autocast(enabled=use_cuda):
                    loss_spec = multi_res_stft_loss(fake, real)
                    loss_l1   = F.l1_loss(fake, real)
                    loss_adv  = generator_loss(D, fake)
                    # Weights: spectral + L1 dominate early, adversarial pushes crispness
                    loss_G = loss_spec * 1.0 + loss_l1 * 10.0 + loss_adv * 0.1

                opt_G.zero_grad()
                scaler.scale(loss_G).backward()
                scaler.unscale_(opt_G)
                torch.nn.utils.clip_grad_norm_(G.parameters(), 1.0)
                scaler.step(opt_G)
                scaler.update()

                g_losses.append(loss_G.item())
                d_losses.append(loss_D.item())
                spec_losses.append(loss_spec.item())

            sched_G.step()
            sched_D.step()

            elapsed   = time.time() - t0
            mean_g    = np.mean(g_losses)
            mean_d    = np.mean(d_losses)
            mean_spec = np.mean(spec_losses)
            # Read after the scheduler step, i.e. the rate the next epoch uses.
            cur_lr    = sched_G.get_last_lr()[0]

            # Log
            log.write(f"{epoch},{mean_g:.5f},{mean_d:.5f},{mean_spec:.5f},{cur_lr:.6f},{elapsed:.1f}\n")
            log.flush()

            if epoch % 10 == 0 or epoch == 1:
                print(f"Epoch {epoch:3d}/{EPOCHS} | "
                      f"G: {mean_g:.4f} | D: {mean_d:.4f} | "
                      f"Spec: {mean_spec:.4f} | "
                      f"LR: {cur_lr:.2e} | {elapsed:.0f}s")

            # Save best
            if mean_spec < best_spec:
                best_spec = mean_spec
                torch.save(G.state_dict(), os.path.join(SAVE_DIR, "tinysr2_best.pt"))

            # Periodic checkpoints
            if epoch % 50 == 0:
                ckpt = os.path.join(SAVE_DIR, f"tinysr2_ep{epoch}.pt")
                torch.save(G.state_dict(), ckpt)
                print(f"  → Checkpoint: {ckpt}")

    # Save final
    torch.save(G.state_dict(), SAVE_PATH)
    size_kb = os.path.getsize(SAVE_PATH) / 1024
    print(f"\nDone! {SAVE_PATH} ({size_kb:.1f} KB float32)")

    # Save float16 version (~half size)
    G_half = TinySR2()
    G_half.load_state_dict(torch.load(SAVE_PATH, map_location="cpu"))
    G_half = G_half.half()
    fp16_path = SAVE_PATH.replace(".pt", "_fp16.pt")
    torch.save(G_half.state_dict(), fp16_path)
    print(f"fp16:  {fp16_path} ({os.path.getsize(fp16_path) / 1024:.1f} KB)")


# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    train()

Benchmark Comparison

| Model | By Myself? | Size |
|---|---|---|
| TinySR 2 (MihaiPopa-1/TinySR) | Yes | 34.3KB |
| NovaSR (YatharthS/NovaSR) | No | 52KB |
| TinySR 1 (MihaiPopa-1/TinySR) | Yes | 57.6KB |
| FlashSR (YatharthS/FlashSR) | No | 2MB (PyTorch) / 500KB (ONNX) |
| AudioSR (haoheliu/audiosr_basic) | No | 6.18GB |

Other things

I tested the model on Odetari - Keep Up (an HE-AACv2 version encoded at 8 kbps, 16 kHz, stereo, with FDK-AAC), and it worked!

It's not exactly like the ground truth (and not perfect), but it's better than I thought!

Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support
Free AI Image Generator No sign-up. Instant results. Open Now