AuriStream-1B

📚 Paper - 🌐 Project Page

AuriStream is a biologically inspired, GPT-style autoregressive Transformer trained to predict tokens from the speech stream (referred to as cochlear tokens). These cochlear tokens are discrete codes produced by a companion “WavCoch” tokenizer, a model trained to predict the time-frequency cochleagram from a waveform, with an LFQ bottleneck for token read-out. AuriStream uses a long context window (~20 s, ~4096 tokens) and is trained on **LibriLight (60k hours)** for 500k steps. It learns meaningful representations of, e.g., phoneme and word identity, and can predict future tokens to generate speech continuations. Inputs are cochlear token IDs; use it with a WavCoch tokenizer to convert audio -> tokens.
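
From the numbers above (~4096 tokens over ~20 s of audio), the cochlear token rate works out to roughly 200 tokens per second; a quick back-of-envelope check using only those figures:

# Approximate token rate implied by the context window quoted above
tokens_per_sec = 4096 / 20        # ~205 cochlear tokens per second of audio
print(round(3 * tokens_per_sec))  # a 3 s prompt yields on the order of ~600 tokens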


Installation

pip install -U torch torchaudio transformers

This model uses custom code; when loading from Hugging Face, pass trust_remote_code=True.


Use case 1) Get hidden state embeddings for an audio waveform

import torch, torchaudio
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# 1) Load the WavCoch tokenizer (audio -> token IDs)
quantizer = AutoModel.from_pretrained(
    "TuKoResearch/WavCochV8192", trust_remote_code=True
).to(device).eval()

# 2) Load the AuriStream LM (tokens -> hidden states / next-token prediction)
lm = AutoModel.from_pretrained(
    "TuKoResearch/AuriStream1B_librilight_ckpt500k", trust_remote_code=True
).to(device).eval()

# 3) Read an audio file (mono, 16 kHz recommended)
wav, sr = torchaudio.load("sample.wav")

if wav.size(0) > 1:  # stereo -> mono
    wav = wav.mean(dim=0, keepdim=True)
if sr != 16_000:
    wav = torchaudio.transforms.Resample(sr, 16_000)(wav)
    sr = 16_000

# 4) Quantize the audio to obtain cochlear token IDs
with torch.no_grad():
    # The quantizer forward method expects (B, 1, T); returns (B, L)
    token_ids = quantizer(wav.unsqueeze(0).to(device))['input_ids']  # (1, L)

# 5) Forward pass to obtain hidden states
with torch.no_grad():
    out = lm(token_ids, output_hidden_states=True)
    last_layer = out["hidden_states"][-1]     # (1, L, D)
    last_layer_mean = last_layer.mean(dim=1)  # mean-pool over time -> (1, D)

print("Mean-pooled embedding shape:", last_layer_mean.shape)

Notes

  • output_hidden_states=True returns all layers.
  • For phoneme/word segments, slice the time axis before pooling (see the sketch below).
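
A minimal sketch of segment-level pooling, continuing from the variables in Use case 1. It assumes cochlear tokens are uniformly spaced in time, so token indices can be estimated from second offsets; the segment boundaries below are hypothetical:

# Map a (start, end) segment in seconds onto token indices, then mean-pool
seg_start_s, seg_end_s = 0.50, 0.85                 # hypothetical word boundaries
audio_seconds = wav.size(-1) / sr
tokens_per_sec = token_ids.size(1) / audio_seconds  # empirical token rate for this clip
i0 = int(seg_start_s * tokens_per_sec)
i1 = int(seg_end_s * tokens_per_sec)
seg_emb = last_layer[:, i0:i1, :].mean(dim=1)       # (1, D)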

Use case 2) Generate a speech continuation (cochlear token prediction)

import torch, torchaudio
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# WavCoch tokenizer (audio -> tokens)
quantizer = AutoModel.from_pretrained(
    "TuKoResearch/WavCochV8192", trust_remote_code=True
).to(device).eval()

# AuriStream LM (tokens -> next tokens)
lm = AutoModel.from_pretrained(
    "TuKoResearch/AuriStream1B_librilight_ckpt500k", trust_remote_code=True
).to(device).eval()

# Load and prep a short prompt (e.g., 3s of audio at 16 kHz)
prompt_seconds = 3
wav, sr = torchaudio.load("prompt.wav")
if wav.size(0) > 1:
    wav = wav.mean(dim=0, keepdim=True)
if sr != 16_000:
    wav = torchaudio.transforms.Resample(sr, 16_000)(wav)
    sr = 16_000
# Slice using an integer number of samples
n_samples = int(round(sr * prompt_seconds))
wav = wav[:, :n_samples]

# Quantize the prompt audio to get token IDs
with torch.no_grad():
    prompt_tokens = quantizer(wav.unsqueeze(0).to(device))['input_ids']  # (1, L)

# Decide how many future tokens to generate ("roll-out")
tokens_per_sec = prompt_tokens.size(1) / float(prompt_seconds)
rollout_seconds = 2
rollout_steps = int(round(tokens_per_sec * rollout_seconds)) # K

# Generate future tokens
with torch.no_grad():
    # returns (pred_tokens, pred_logits); temp/top_k/top_p/seed are optional
    pred_tokens, _ = lm.generate(
        prompt_tokens, rollout_steps, temp=0.7, top_k=50, top_p=0.95, seed=0
    )
    full_tokens = torch.cat([prompt_tokens, pred_tokens], dim=1)  # (1, L+K)
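
Because sampling is stochastic, you can draw several alternative continuations from the same prompt by varying the seed. An illustrative loop using only the generate signature shown above:

# Sample a few alternative continuations of the same prompt
continuations = []
with torch.no_grad():
    for seed in range(3):
        toks, _ = lm.generate(
            prompt_tokens, rollout_steps, temp=0.9, top_k=50, top_p=0.95, seed=seed
        )
        continuations.append(toks)
print([tuple(c.shape) for c in continuations])  # each is (1, K)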

Architecture overview

Schematic of the WavCoch tokenizer (panel A) and the AuriStream model (panel B).

Citation

If you use this model, please cite:

@inproceedings{tuckute2025cochleartokens,
  title     = {Representing Speech Through Autoregressive Prediction of Cochlear Tokens},
  author    = {Greta Tuckute and Klemen Kotar and Evelina Fedorenko and Daniel Yamins},
  booktitle = {Interspeech 2025},
  year      = {2025},
  pages     = {2180--2184},
  doi       = {10.21437/Interspeech.2025-2044},
  issn      = {2958-1796}
}