stevhliu HF Staff commited on
Commit
3a2b2c9
·
1 Parent(s): ba50986

Create pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +21 -0
pipeline.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from diffusers import DiffusionPipeline
4
+
5
+
6
+ class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
7
+ def __init__(self, unet, scheduler):
8
+ super().__init__()
9
+
10
+ self.register_modules(unet=unet, scheduler=scheduler)
11
+
12
+ def __call__(self):
13
+ image = torch.randn(
14
+ (1, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
15
+ )
16
+ timestep = 1
17
+
18
+ model_output = self.unet(image, timestep).sample
19
+ scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample
20
+
21
+ return scheduler_output