HyperNoise Sana Sprint 0.6B
HyperNoise enables reward-based fine-tuning of distilled diffusion models, for example aligning them with human-preference reward models, without expensive test-time optimization.
Model Description
This model is a LoRA adapter trained as a noise hypernetwork, as described in Noise Hypernetworks: Amortizing Test-Time Compute in Diffusion Models. Instead of modifying the base generator's parameters, HyperNoise learns to predict an optimal initial noise distribution that steers the frozen SANA-Sprint model towards higher-quality, preference-aligned outputs. The model was trained on an ensemble of human-preference reward models.
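At inference time, the idea can be sketched as follows: a hypernetwork predicts a residual that is added to the Gaussian noise, and the unmodified generator consumes the modulated noise. This mirrors the `.sample + init_latents` step in the usage example. The sketch is a toy under stated assumptions: `toy_generator` and `toy_hypernet` are hypothetical stand-ins operating on plain Python lists, not the SANA-Sprint transformer or the trained adapter.

```python
import random
from typing import Callable, List, Optional

Vector = List[float]
Fn = Callable[[Vector, Vector], Vector]


def modulate_noise(z: Vector, cond: Vector, hypernet: Fn) -> Vector:
    """Return z + hypernet(z, cond): the hypernetwork predicts a residual on the noise."""
    delta = hypernet(z, cond)
    return [zi + di for zi, di in zip(z, delta)]


def generate(z: Vector, cond: Vector, generator: Fn, hypernet: Optional[Fn] = None) -> Vector:
    """Run the frozen generator, optionally on hypernetwork-modulated noise."""
    if hypernet is not None:
        z = modulate_noise(z, cond, hypernet)
    return generator(z, cond)  # the generator itself is never modified


# Hypothetical toy stand-ins for the frozen one-step generator and the hypernetwork.
def toy_generator(z: Vector, cond: Vector) -> Vector:
    return [2.0 * zi + ci for zi, ci in zip(z, cond)]


def toy_hypernet(z: Vector, cond: Vector) -> Vector:
    # Nudges each noise entry a little towards the conditioning vector.
    return [0.1 * (ci - zi) for zi, ci in zip(z, cond)]


rng = random.Random(0)
z = [rng.gauss(0.0, 1.0) for _ in range(4)]
cond = [1.0, -1.0, 0.5, 0.0]
base_out = generate(z, cond, toy_generator)
steered_out = generate(z, cond, toy_generator, toy_hypernet)
```

Because only the noise changes, a zero residual recovers the base model's output exactly, which is what keeps the base generator frozen.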
Key Features
- Amortized optimization: Amortizes expensive test-time noise optimization into a single forward pass
- Theoretical foundation: A tractable loss for learning the reward-tilted distribution of distilled diffusion models
- Multi-step generalization: Trained for 1-step generation, but improves quality across 1 to 32 inference steps
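A reward-tilted distribution reweights the base model's distribution towards high-reward samples. One common form, assumed here, is p*(x) ∝ p_base(x) · exp(r(x) / β) with a temperature β > 0; the exact form and constants used for HyperNoise may differ, so refer to the paper. A toy sketch on a discrete distribution, where the sample names and rewards are hypothetical:

```python
import math
from typing import Dict


def reward_tilt(base: Dict[str, float], reward: Dict[str, float], beta: float = 1.0) -> Dict[str, float]:
    """Reweight a discrete base distribution by exp(reward / beta) and renormalize.

    Assumed form: p*(x) proportional to p_base(x) * exp(r(x) / beta).
    """
    if beta <= 0:
        raise ValueError("beta must be positive")
    weights = {x: p * math.exp(reward[x] / beta) for x, p in base.items()}
    total = sum(weights.values())
    return {x: w / total for x, w in weights.items()}


# Hypothetical base probabilities and rewards for three candidate outputs.
base = {"a": 0.5, "b": 0.3, "c": 0.2}
reward = {"a": 0.0, "b": 1.0, "c": 2.0}
tilted = reward_tilt(base, reward, beta=1.0)
```

Larger β weakens the tilt (β → ∞ recovers the base distribution); smaller β concentrates mass on the highest-reward outputs. HyperNoise aims to sample from such a tilted distribution through a learned noise map rather than per-sample optimization.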
GitHub
For more information, check out the GitHub repository and the project page.
Usage example
```python
import types

import peft
import torch
from diffusers import SanaSprintPipeline
from peft.tuners.lora.layer import Linear as LoraLinear

prompt = "A smiling slice of pizza doing yoga on a mountain top."
adapter_name = "hypernoise_adapter"
device = torch.device("cuda")

# Load the frozen SANA-Sprint base pipeline.
pipe = SanaSprintPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers",
    torch_dtype=torch.bfloat16,
).to(device, torch.bfloat16)

# Wrap the transformer with the HyperNoise LoRA adapter.
pipe.transformer = peft.PeftModel.from_pretrained(
    pipe.transformer,
    "lucaeyring/HyperNoise_Sana_Sprint_0.6B",
    adapter_name=adapter_name,
    dtype=torch.bfloat16,
).to(device, torch.bfloat16)


def scaled_base_lora_forward(self, x, *args, **kwargs):
    # Adapters disabled: behave exactly like the original layer.
    if self.disable_adapters:
        return self.base_layer(x, *args, **kwargs)
    # Adapters enabled: return only the scaled LoRA term, without the base projection.
    return self.lora_B[adapter_name](self.lora_A[adapter_name](x)) * self.scaling[adapter_name]


# Patch the output projection (proj_out) of the adapted transformer.
for name, module in pipe.transformer.base_model.model.named_modules():
    if name == "proj_out" and isinstance(module, LoraLinear):
        module.forward = types.MethodType(scaled_base_lora_forward, module)
        break

with torch.inference_mode():
    prompt_embeds, prompt_attention_mask = pipe.encode_prompt([prompt], device=device)

    # Standard Gaussian noise in latent space: (batch, channels, height, width).
    init_latents = torch.randn([1, 32, 32, 32], device=device, dtype=torch.bfloat16)

    # Hypernetwork pass: predict a residual on the noise and add it back.
    pipe.transformer.enable_adapter_layers()
    modulated_latents = pipe.transformer(
        hidden_states=init_latents,
        encoder_hidden_states=prompt_embeds,
        encoder_attention_mask=prompt_attention_mask,
        guidance=torch.tensor([4.5], device=device, dtype=torch.bfloat16) * 0.1,
        timestep=torch.tensor([1.0], device=device, dtype=torch.bfloat16),
    ).sample + init_latents

    # Generate with the frozen base model, starting from the modulated noise.
    pipe.transformer.disable_adapter_layers()
    hypernoise_image = pipe(
        latents=modulated_latents,
        prompt_embeds=prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        num_inference_steps=4,
    ).images[0]

hypernoise_image.save("hypernoise-sana.png")
```
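The `scaled_base_lora_forward` patch changes what `proj_out` returns while adapters are enabled. The toy comparison below is a pure-Python sketch, not PEFT's implementation: `W`, `A` and `B` are hypothetical matrices standing in for the base weight and the two LoRA factors, and dropout and dtype handling are ignored. A standard LoRA linear layer returns `W x + scaling * B(A x)`; the patched version returns only `scaling * B(A x)`.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Dense matrix-vector product."""
    return [sum(mij * vj for mij, vj in zip(row, v)) for row in m]


def lora_forward(x: Vector, W: Matrix, A: Matrix, B: Matrix, scaling: float, adapters_enabled: bool = True) -> Vector:
    """Standard LoRA linear layer: W x + scaling * B (A x)."""
    base = matvec(W, x)
    if not adapters_enabled:
        return base
    delta = matvec(B, matvec(A, x))
    return [b + scaling * d for b, d in zip(base, delta)]


def patched_forward(x: Vector, W: Matrix, A: Matrix, B: Matrix, scaling: float, adapters_enabled: bool = True) -> Vector:
    """Mimics scaled_base_lora_forward: only the scaled LoRA term when adapters are on."""
    if not adapters_enabled:
        return matvec(W, x)
    return [scaling * d for d in matvec(B, matvec(A, x))]


# Hypothetical rank-1 example: W is 2x2, A is 1x2, B is 2x1.
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0, 1.0]]
B = [[1.0], [2.0]]
x = [1.0, 2.0]
standard = lora_forward(x, W, A, B, scaling=0.5)
patched = patched_forward(x, W, A, B, scaling=0.5)
```

With these values the standard layer yields `[2.5, 5.0]` and the patched layer `[1.5, 3.0]`; with adapters disabled both fall back to `W x = [1.0, 2.0]`. In PEFT's default LoRA configuration, `scaling` is `lora_alpha / r`.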