ApoorvBrooklyn committed (verified)
Commit: b733626
1 Parent(s): 30492b8

Delete files attention.py clip.py ddpm.py decoder.py demo.py diffusion.py encoder.py model_converter.py model_loader.py pipeline.py with huggingface_hub

Files changed (10)
  1. attention.py +0 -122
  2. clip.py +0 -96
  3. ddpm.py +0 -123
  4. decoder.py +0 -177
  5. demo.py +0 -67
  6. diffusion.py +0 -349
  7. encoder.py +0 -103
  8. model_converter.py +0 -0
  9. model_loader.py +0 -28
  10. pipeline.py +0 -170
attention.py DELETED
@@ -1,122 +0,0 @@
- import torch
- from torch import nn
- from torch.nn import functional as F
- import math
-
- class SelfAttention(nn.Module):
-     def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True):
-         super().__init__()
-         # This combines the Wq, Wk and Wv matrices into one matrix
-         self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
-         # This one represents the Wo matrix
-         self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
-         self.n_heads = n_heads
-         self.d_head = d_embed // n_heads
-
-     def forward(self, x, causal_mask=False):
-         # x: (Batch_Size, Seq_Len, Dim)
-
-         # (Batch_Size, Seq_Len, Dim)
-         input_shape = x.shape
-
-         # (Batch_Size, Seq_Len, Dim)
-         batch_size, sequence_length, d_embed = input_shape
-
-         # (Batch_Size, Seq_Len, H, Dim / H)
-         interim_shape = (batch_size, sequence_length, self.n_heads, self.d_head)
-
-         # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim * 3) -> 3 tensors of shape (Batch_Size, Seq_Len, Dim)
-         q, k, v = self.in_proj(x).chunk(3, dim=-1)
-
-         # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, H, Dim / H) -> (Batch_Size, H, Seq_Len, Dim / H)
-         q = q.view(interim_shape).transpose(1, 2)
-         k = k.view(interim_shape).transpose(1, 2)
-         v = v.view(interim_shape).transpose(1, 2)
-
-         # (Batch_Size, H, Seq_Len, Dim / H) @ (Batch_Size, H, Dim / H, Seq_Len) -> (Batch_Size, H, Seq_Len, Seq_Len)
-         weight = q @ k.transpose(-1, -2)
-
-         if causal_mask:
-             # Mask where the upper triangle (above the principal diagonal) is 1
-             mask = torch.ones_like(weight, dtype=torch.bool).triu(1)
-             # Fill the upper triangle with -inf
-             weight.masked_fill_(mask, -torch.inf)
-
-         # Divide by d_k (Dim / H).
-         # (Batch_Size, H, Seq_Len, Seq_Len) -> (Batch_Size, H, Seq_Len, Seq_Len)
-         weight /= math.sqrt(self.d_head)
-
-         # (Batch_Size, H, Seq_Len, Seq_Len) -> (Batch_Size, H, Seq_Len, Seq_Len)
-         weight = F.softmax(weight, dim=-1)
-
-         # (Batch_Size, H, Seq_Len, Seq_Len) @ (Batch_Size, H, Seq_Len, Dim / H) -> (Batch_Size, H, Seq_Len, Dim / H)
-         output = weight @ v
-
-         # (Batch_Size, H, Seq_Len, Dim / H) -> (Batch_Size, Seq_Len, H, Dim / H)
-         output = output.transpose(1, 2)
-
-         # (Batch_Size, Seq_Len, H, Dim / H) -> (Batch_Size, Seq_Len, Dim)
-         output = output.reshape(input_shape)
-
-         # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
-         output = self.out_proj(output)
-
-         # (Batch_Size, Seq_Len, Dim)
-         return output
-
- class CrossAttention(nn.Module):
-     def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
-         super().__init__()
-         self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
-         self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
-         self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
-         self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
-         self.n_heads = n_heads
-         self.d_head = d_embed // n_heads
-
-     def forward(self, x, y):
-         # x (latent): (Batch_Size, Seq_Len_Q, Dim_Q)
-         # y (context): (Batch_Size, Seq_Len_KV, Dim_KV) = (Batch_Size, 77, 768)
-
-         input_shape = x.shape
-         batch_size, sequence_length, d_embed = input_shape
-         # Divide each embedding of Q into multiple heads such that d_head * n_heads = Dim_Q
-         interim_shape = (batch_size, -1, self.n_heads, self.d_head)
-
-         # (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, Seq_Len_Q, Dim_Q)
-         q = self.q_proj(x)
-         # (Batch_Size, Seq_Len_KV, Dim_KV) -> (Batch_Size, Seq_Len_KV, Dim_Q)
-         k = self.k_proj(y)
-         # (Batch_Size, Seq_Len_KV, Dim_KV) -> (Batch_Size, Seq_Len_KV, Dim_Q)
-         v = self.v_proj(y)
-
-         # (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, Seq_Len_Q, H, Dim_Q / H) -> (Batch_Size, H, Seq_Len_Q, Dim_Q / H)
-         q = q.view(interim_shape).transpose(1, 2)
-         # (Batch_Size, Seq_Len_KV, Dim_Q) -> (Batch_Size, Seq_Len_KV, H, Dim_Q / H) -> (Batch_Size, H, Seq_Len_KV, Dim_Q / H)
-         k = k.view(interim_shape).transpose(1, 2)
-         # (Batch_Size, Seq_Len_KV, Dim_Q) -> (Batch_Size, Seq_Len_KV, H, Dim_Q / H) -> (Batch_Size, H, Seq_Len_KV, Dim_Q / H)
-         v = v.view(interim_shape).transpose(1, 2)
-
-         # (Batch_Size, H, Seq_Len_Q, Dim_Q / H) @ (Batch_Size, H, Dim_Q / H, Seq_Len_KV) -> (Batch_Size, H, Seq_Len_Q, Seq_Len_KV)
-         weight = q @ k.transpose(-1, -2)
-
-         # (Batch_Size, H, Seq_Len_Q, Seq_Len_KV)
-         weight /= math.sqrt(self.d_head)
-
-         # (Batch_Size, H, Seq_Len_Q, Seq_Len_KV)
-         weight = F.softmax(weight, dim=-1)
-
-         # (Batch_Size, H, Seq_Len_Q, Seq_Len_KV) @ (Batch_Size, H, Seq_Len_KV, Dim_Q / H) -> (Batch_Size, H, Seq_Len_Q, Dim_Q / H)
-         output = weight @ v
-
-         # (Batch_Size, H, Seq_Len_Q, Dim_Q / H) -> (Batch_Size, Seq_Len_Q, H, Dim_Q / H)
-         output = output.transpose(1, 2).contiguous()
-
-         # (Batch_Size, Seq_Len_Q, H, Dim_Q / H) -> (Batch_Size, Seq_Len_Q, Dim_Q)
-         output = output.view(input_shape)
-
-         # (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, Seq_Len_Q, Dim_Q)
-         output = self.out_proj(output)
-
-         # (Batch_Size, Seq_Len_Q, Dim_Q)
-         return output
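
For reference, a minimal usage sketch of the two blocks defined in the deleted file (they were consumed by clip.py, decoder.py and diffusion.py below). The head count, sequence lengths and embedding widths here are illustrative assumptions, not values pinned by this commit:

import torch
from attention import SelfAttention, CrossAttention  # assumes the deleted module is still importable

# Self-attention over a token sequence: (Batch, Seq_Len, Dim) -> (Batch, Seq_Len, Dim)
self_attn = SelfAttention(n_heads=8, d_embed=320)
x = torch.randn(1, 64, 320)
print(self_attn(x, causal_mask=False).shape)  # torch.Size([1, 64, 320])

# Cross-attention: latent tokens attend to a CLIP context of width 768
cross_attn = CrossAttention(n_heads=8, d_embed=320, d_cross=768)
context = torch.randn(1, 77, 768)
print(cross_attn(x, context).shape)           # torch.Size([1, 64, 320])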
 
clip.py DELETED
@@ -1,96 +0,0 @@
- import torch
- from torch import nn
- from torch.nn import functional as F
- from attention import SelfAttention
-
- class CLIPEmbedding(nn.Module):
-     def __init__(self, n_vocab: int, n_embd: int, n_token: int):
-         super().__init__()
-
-         self.token_embedding = nn.Embedding(n_vocab, n_embd)
-         # A learnable weight matrix encodes the position information for each token
-         self.position_embedding = nn.Parameter(torch.zeros((n_token, n_embd)))
-
-     def forward(self, tokens):
-         # (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
-         x = self.token_embedding(tokens)
-         # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
-         x += self.position_embedding
-
-         return x
-
- class CLIPLayer(nn.Module):
-     def __init__(self, n_head: int, n_embd: int):
-         super().__init__()
-
-         # Pre-attention norm
-         self.layernorm_1 = nn.LayerNorm(n_embd)
-         # Self attention
-         self.attention = SelfAttention(n_head, n_embd)
-         # Pre-FFN norm
-         self.layernorm_2 = nn.LayerNorm(n_embd)
-         # Feedforward layer
-         self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
-         self.linear_2 = nn.Linear(4 * n_embd, n_embd)
-
-     def forward(self, x):
-         # (Batch_Size, Seq_Len, Dim)
-         residue = x
-
-         ### SELF ATTENTION ###
-
-         # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
-         x = self.layernorm_1(x)
-
-         # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
-         x = self.attention(x, causal_mask=True)
-
-         # (Batch_Size, Seq_Len, Dim) + (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
-         x += residue
-
-         ### FEEDFORWARD LAYER ###
-         # Apply a feedforward layer where the hidden dimension is 4 times the embedding dimension.
-
-         residue = x
-         # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
-         x = self.layernorm_2(x)
-
-         # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, 4 * Dim)
-         x = self.linear_1(x)
-
-         # (Batch_Size, Seq_Len, 4 * Dim) -> (Batch_Size, Seq_Len, 4 * Dim)
-         x = x * torch.sigmoid(1.702 * x)  # QuickGELU activation function
-
-         # (Batch_Size, Seq_Len, 4 * Dim) -> (Batch_Size, Seq_Len, Dim)
-         x = self.linear_2(x)
-
-         # (Batch_Size, Seq_Len, Dim) + (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
-         x += residue
-
-         return x
-
- class CLIP(nn.Module):
-     def __init__(self):
-         super().__init__()
-         self.embedding = CLIPEmbedding(49408, 768, 77)
-
-         self.layers = nn.ModuleList([
-             CLIPLayer(12, 768) for i in range(12)
-         ])
-
-         self.layernorm = nn.LayerNorm(768)
-
-     def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
-         tokens = tokens.type(torch.long)
-
-         # (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
-         state = self.embedding(tokens)
-
-         # Apply encoder layers similar to the Transformer's encoder.
-         for layer in self.layers:
-             # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
-             state = layer(state)
-         # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
-         output = self.layernorm(state)
-
-         return output
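
As a quick sanity check of the text encoder's interface: a (Batch, 77) tensor of token ids maps to a (Batch, 77, 768) context. The random token ids below are purely illustrative; in pipeline.py the ids came from a CLIPTokenizer padded to 77 tokens:

import torch
from clip import CLIP  # assumes the deleted module is still importable

clip_model = CLIP()
tokens = torch.randint(0, 49408, (1, 77), dtype=torch.long)  # random ids, shape check only
with torch.no_grad():
    context = clip_model(tokens)
print(context.shape)  # torch.Size([1, 77, 768])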
 
ddpm.py DELETED
@@ -1,123 +0,0 @@
- import torch
- import numpy as np
-
- class DDPMSampler:
-
-     def __init__(self, generator: torch.Generator, num_training_steps=1000, beta_start: float = 0.00085, beta_end: float = 0.0120):
-         # Params "beta_start" and "beta_end" taken from: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/configs/stable-diffusion/v1-inference.yaml#L5C8-L5C8
-         # For the naming conventions, refer to the DDPM paper (https://arxiv.org/pdf/2006.11239.pdf)
-         self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_training_steps, dtype=torch.float32) ** 2
-         self.alphas = 1.0 - self.betas
-         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
-         self.one = torch.tensor(1.0)
-
-         self.generator = generator
-
-         self.num_train_timesteps = num_training_steps
-         self.timesteps = torch.from_numpy(np.arange(0, num_training_steps)[::-1].copy())
-
-     def set_inference_timesteps(self, num_inference_steps=50):
-         self.num_inference_steps = num_inference_steps
-         step_ratio = self.num_train_timesteps // self.num_inference_steps
-         timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
-         self.timesteps = torch.from_numpy(timesteps)
-
-     def _get_previous_timestep(self, timestep: int) -> int:
-         prev_t = timestep - self.num_train_timesteps // self.num_inference_steps
-         return prev_t
-
-     def _get_variance(self, timestep: int) -> torch.Tensor:
-         prev_t = self._get_previous_timestep(timestep)
-
-         alpha_prod_t = self.alphas_cumprod[timestep]
-         alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
-         current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
-
-         # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
-         # and sample from it to get previous sample
-         # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
-         variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
-
-         # we always take the log of variance, so clamp it to ensure it's not 0
-         variance = torch.clamp(variance, min=1e-20)
-
-         return variance
-
-     def set_strength(self, strength=1):
-         """
-         Set how much noise to add to the input image.
-         More noise (strength ~ 1) means that the output will be further from the input image.
-         Less noise (strength ~ 0) means that the output will be closer to the input image.
-         """
-         # start_step is the number of noise levels to skip
-         start_step = self.num_inference_steps - int(self.num_inference_steps * strength)
-         self.timesteps = self.timesteps[start_step:]
-         self.start_step = start_step
-
-     def step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
-         t = timestep
-         prev_t = self._get_previous_timestep(t)
-
-         # 1. compute alphas, betas
-         alpha_prod_t = self.alphas_cumprod[t]
-         alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
-         beta_prod_t = 1 - alpha_prod_t
-         beta_prod_t_prev = 1 - alpha_prod_t_prev
-         current_alpha_t = alpha_prod_t / alpha_prod_t_prev
-         current_beta_t = 1 - current_alpha_t
-
-         # 2. compute predicted original sample from predicted noise also called
-         # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
-         pred_original_sample = (latents - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
-
-         # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
-         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
-         pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
-         current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
-
-         # 5. Compute predicted previous sample µ_t
-         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
-         pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents
-
-         # 6. Add noise
-         variance = 0
-         if t > 0:
-             device = model_output.device
-             noise = torch.randn(model_output.shape, generator=self.generator, device=device, dtype=model_output.dtype)
-             # Compute the variance as per formula (7) from https://arxiv.org/pdf/2006.11239.pdf
-             variance = (self._get_variance(t) ** 0.5) * noise
-
-         # sample from N(mu, sigma) = X can be obtained by X = mu + sigma * N(0, 1)
-         # the variable "variance" is already multiplied by the noise N(0, 1)
-         pred_prev_sample = pred_prev_sample + variance
-
-         return pred_prev_sample
-
-     def add_noise(
-         self,
-         original_samples: torch.FloatTensor,
-         timesteps: torch.IntTensor,
-     ) -> torch.FloatTensor:
-         alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
-         timesteps = timesteps.to(original_samples.device)
-
-         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
-         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
-         while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
-             sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
-
-         sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
-         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
-         while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
-             sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
-
-         # Sample from q(x_t | x_0) as in equation (4) of https://arxiv.org/pdf/2006.11239.pdf
-         # Because N(mu, sigma) = X can be obtained by X = mu + sigma * N(0, 1)
-         # here mu = sqrt_alpha_prod * original_samples and sigma = sqrt_one_minus_alpha_prod
-         noise = torch.randn(original_samples.shape, generator=self.generator, device=original_samples.device, dtype=original_samples.dtype)
-         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
-         return noisy_samples
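
A minimal sketch of the sampler's intended call pattern (configure the inference schedule, then repeatedly call step() with the model's noise prediction). The random tensor below merely stands in for the U-Net output, and the 64x64 latent size is an assumption matching the 512x512 pipeline:

import torch
from ddpm import DDPMSampler  # assumes the deleted module is still importable

generator = torch.Generator().manual_seed(42)
sampler = DDPMSampler(generator)
sampler.set_inference_timesteps(50)

latents = torch.randn(1, 4, 64, 64, generator=generator)
for t in sampler.timesteps:
    predicted_noise = torch.randn_like(latents)  # stand-in for the diffusion model's prediction
    latents = sampler.step(int(t), latents, predicted_noise)
print(latents.shape)  # torch.Size([1, 4, 64, 64])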
 
decoder.py DELETED
@@ -1,177 +0,0 @@
- import torch
- from torch import nn
- from torch.nn import functional as F
- from attention import SelfAttention
-
- class VAE_AttentionBlock(nn.Module):
-     def __init__(self, channels):
-         super().__init__()
-         self.groupnorm = nn.GroupNorm(32, channels)
-         self.attention = SelfAttention(1, channels)
-
-     def forward(self, x):
-         # x: (Batch_Size, Features, Height, Width)
-
-         residue = x
-
-         # (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height, Width)
-         x = self.groupnorm(x)
-
-         n, c, h, w = x.shape
-
-         # (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height * Width)
-         x = x.view((n, c, h * w))
-
-         # (Batch_Size, Features, Height * Width) -> (Batch_Size, Height * Width, Features). Each pixel becomes a feature of size "Features", the sequence length is "Height * Width".
-         x = x.transpose(-1, -2)
-
-         # Perform self-attention WITHOUT mask
-         # (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
-         x = self.attention(x)
-
-         # (Batch_Size, Height * Width, Features) -> (Batch_Size, Features, Height * Width)
-         x = x.transpose(-1, -2)
-
-         # (Batch_Size, Features, Height * Width) -> (Batch_Size, Features, Height, Width)
-         x = x.view((n, c, h, w))
-
-         # (Batch_Size, Features, Height, Width) + (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height, Width)
-         x += residue
-
-         # (Batch_Size, Features, Height, Width)
-         return x
-
- class VAE_ResidualBlock(nn.Module):
-     def __init__(self, in_channels, out_channels):
-         super().__init__()
-         self.groupnorm_1 = nn.GroupNorm(32, in_channels)
-         self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
-
-         self.groupnorm_2 = nn.GroupNorm(32, out_channels)
-         self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
-
-         if in_channels == out_channels:
-             self.residual_layer = nn.Identity()
-         else:
-             self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
-
-     def forward(self, x):
-         # x: (Batch_Size, In_Channels, Height, Width)
-
-         residue = x
-
-         # (Batch_Size, In_Channels, Height, Width) -> (Batch_Size, In_Channels, Height, Width)
-         x = self.groupnorm_1(x)
-
-         # (Batch_Size, In_Channels, Height, Width) -> (Batch_Size, In_Channels, Height, Width)
-         x = F.silu(x)
-
-         # (Batch_Size, In_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
-         x = self.conv_1(x)
-
-         # (Batch_Size, Out_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
-         x = self.groupnorm_2(x)
-
-         # (Batch_Size, Out_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
-         x = F.silu(x)
-
-         # (Batch_Size, Out_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
-         x = self.conv_2(x)
-
-         # (Batch_Size, Out_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
-         return x + self.residual_layer(residue)
-
- class VAE_Decoder(nn.Sequential):
-     def __init__(self):
-         super().__init__(
-             # (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 4, Height / 8, Width / 8)
-             nn.Conv2d(4, 4, kernel_size=1, padding=0),
-
-             # (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
-             nn.Conv2d(4, 512, kernel_size=3, padding=1),
-
-             # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
-             VAE_ResidualBlock(512, 512),
-
-             # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
-             VAE_AttentionBlock(512),
-
-             # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
-             VAE_ResidualBlock(512, 512),
-
-             # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
-             VAE_ResidualBlock(512, 512),
-
-             # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
-             VAE_ResidualBlock(512, 512),
-
-             # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
-             VAE_ResidualBlock(512, 512),
-
-             # Repeats the rows and columns of the data by scale_factor (like when you resize an image by doubling its size).
-             # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 4, Width / 4)
-             nn.Upsample(scale_factor=2),
-
-             # (Batch_Size, 512, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 4, Width / 4)
-             nn.Conv2d(512, 512, kernel_size=3, padding=1),
-
-             # (Batch_Size, 512, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 4, Width / 4)
-             VAE_ResidualBlock(512, 512),
-
-             # (Batch_Size, 512, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 4, Width / 4)
-             VAE_ResidualBlock(512, 512),
-
-             # (Batch_Size, 512, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 4, Width / 4)
-             VAE_ResidualBlock(512, 512),
-
-             # (Batch_Size, 512, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 2, Width / 2)
-             nn.Upsample(scale_factor=2),
-
-             # (Batch_Size, 512, Height / 2, Width / 2) -> (Batch_Size, 512, Height / 2, Width / 2)
-             nn.Conv2d(512, 512, kernel_size=3, padding=1),
-
-             # (Batch_Size, 512, Height / 2, Width / 2) -> (Batch_Size, 256, Height / 2, Width / 2)
-             VAE_ResidualBlock(512, 256),
-
-             # (Batch_Size, 256, Height / 2, Width / 2) -> (Batch_Size, 256, Height / 2, Width / 2)
-             VAE_ResidualBlock(256, 256),
-
-             # (Batch_Size, 256, Height / 2, Width / 2) -> (Batch_Size, 256, Height / 2, Width / 2)
-             VAE_ResidualBlock(256, 256),
-
-             # (Batch_Size, 256, Height / 2, Width / 2) -> (Batch_Size, 256, Height, Width)
-             nn.Upsample(scale_factor=2),
-
-             # (Batch_Size, 256, Height, Width) -> (Batch_Size, 256, Height, Width)
-             nn.Conv2d(256, 256, kernel_size=3, padding=1),
-
-             # (Batch_Size, 256, Height, Width) -> (Batch_Size, 128, Height, Width)
-             VAE_ResidualBlock(256, 128),
-
-             # (Batch_Size, 128, Height, Width) -> (Batch_Size, 128, Height, Width)
-             VAE_ResidualBlock(128, 128),
-
-             # (Batch_Size, 128, Height, Width) -> (Batch_Size, 128, Height, Width)
-             VAE_ResidualBlock(128, 128),
-
-             # (Batch_Size, 128, Height, Width) -> (Batch_Size, 128, Height, Width)
-             nn.GroupNorm(32, 128),
-
-             # (Batch_Size, 128, Height, Width) -> (Batch_Size, 128, Height, Width)
-             nn.SiLU(),
-
-             # (Batch_Size, 128, Height, Width) -> (Batch_Size, 3, Height, Width)
-             nn.Conv2d(128, 3, kernel_size=3, padding=1),
-         )
-
-     def forward(self, x):
-         # x: (Batch_Size, 4, Height / 8, Width / 8)
-
-         # Remove the scaling added by the Encoder.
-         x /= 0.18215
-
-         for module in self:
-             x = module(x)
-
-         # (Batch_Size, 3, Height, Width)
-         return x
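
A shape-only sketch of the decoder: a (Batch, 4, Height / 8, Width / 8) latent maps to a (Batch, 3, Height, Width) image. With randomly initialized weights the output is noise; the 512x512 resolution is just an assumption mirroring pipeline.py:

import torch
from decoder import VAE_Decoder  # assumes the deleted module is still importable

decoder = VAE_Decoder()
latents = torch.randn(1, 4, 64, 64)  # 64 = 512 / 8
with torch.no_grad():
    image = decoder(latents)
print(image.shape)  # torch.Size([1, 3, 512, 512])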
 
demo.py DELETED
@@ -1,67 +0,0 @@
- import model_loader
- import pipeline
- from PIL import Image
- from pathlib import Path
- from transformers import CLIPTokenizer
- import torch
-
- DEVICE = "cpu"
-
- ALLOW_CUDA = False
- ALLOW_MPS = False
-
- if torch.cuda.is_available() and ALLOW_CUDA:
-     DEVICE = "cuda"
- elif torch.backends.mps.is_available() and ALLOW_MPS:
-     DEVICE = "mps"
- print(f"Using device: {DEVICE}")
-
- tokenizer = CLIPTokenizer("../data/vocab.json", merges_file="../data/merges.txt")
- model_file = "../data/v1-5-pruned-emaonly.ckpt"
- models = model_loader.preload_models_from_standard_weights(model_file, DEVICE)
-
- ## TEXT TO IMAGE
-
- # prompt = "A dog with sunglasses, wearing comfy hat, looking at camera, highly detailed, ultra sharp, cinematic, 100mm lens, 8k resolution."
- prompt = "A boy playing football with his teammates."
- uncond_prompt = ""  # Also known as negative prompt
- do_cfg = True
- cfg_scale = 8  # min: 1, max: 14
-
- ## IMAGE TO IMAGE
-
- input_image = None
- # Uncomment the Image.open line below to enable image to image
- image_path = "../images/dog.jpg"
- # input_image = Image.open(image_path)
- # Higher values mean more noise is added to the input image, so the result will be further from the input image.
- # Lower values mean less noise is added to the input image, so the output will be closer to the input image.
- strength = 0.9
-
- ## SAMPLER
-
- sampler = "ddpm"
- num_inference_steps = 50
- seed = 42
-
- output_image = pipeline.generate(
-     prompt=prompt,
-     uncond_prompt=uncond_prompt,
-     input_image=input_image,
-     strength=strength,
-     do_cfg=do_cfg,
-     cfg_scale=cfg_scale,
-     sampler_name=sampler,
-     n_inference_steps=num_inference_steps,
-     seed=seed,
-     models=models,
-     device=DEVICE,
-     idle_device="cpu",
-     tokenizer=tokenizer,
- )
-
- # Convert the output array to a PIL image and save it.
- result_img = Image.fromarray(output_image)
- result_img.save("output.png")
- print("Saved output.png")
 
diffusion.py DELETED
@@ -1,349 +0,0 @@
- import torch
- from torch import nn
- from torch.nn import functional as F
- from attention import SelfAttention, CrossAttention
-
- class TimeEmbedding(nn.Module):
-     def __init__(self, n_embd):
-         super().__init__()
-         self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
-         self.linear_2 = nn.Linear(4 * n_embd, 4 * n_embd)
-
-     def forward(self, x):
-         # x: (1, 320)
-
-         # (1, 320) -> (1, 1280)
-         x = self.linear_1(x)
-
-         # (1, 1280) -> (1, 1280)
-         x = F.silu(x)
-
-         # (1, 1280) -> (1, 1280)
-         x = self.linear_2(x)
-
-         return x
-
- class UNET_ResidualBlock(nn.Module):
-     def __init__(self, in_channels, out_channels, n_time=1280):
-         super().__init__()
-         self.groupnorm_feature = nn.GroupNorm(32, in_channels)
-         self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
-         self.linear_time = nn.Linear(n_time, out_channels)
-
-         self.groupnorm_merged = nn.GroupNorm(32, out_channels)
-         self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
-
-         if in_channels == out_channels:
-             self.residual_layer = nn.Identity()
-         else:
-             self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
-
-     def forward(self, feature, time):
-         # feature: (Batch_Size, In_Channels, Height, Width)
-         # time: (1, 1280)
-
-         residue = feature
-
-         # (Batch_Size, In_Channels, Height, Width) -> (Batch_Size, In_Channels, Height, Width)
-         feature = self.groupnorm_feature(feature)
-
-         # (Batch_Size, In_Channels, Height, Width) -> (Batch_Size, In_Channels, Height, Width)
-         feature = F.silu(feature)
-
-         # (Batch_Size, In_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
-         feature = self.conv_feature(feature)
-
-         # (1, 1280) -> (1, 1280)
-         time = F.silu(time)
-
-         # (1, 1280) -> (1, Out_Channels)
-         time = self.linear_time(time)
-
-         # Add width and height dimension to time.
-         # (Batch_Size, Out_Channels, Height, Width) + (1, Out_Channels, 1, 1) -> (Batch_Size, Out_Channels, Height, Width)
-         merged = feature + time.unsqueeze(-1).unsqueeze(-1)
-
-         # (Batch_Size, Out_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
-         merged = self.groupnorm_merged(merged)
-
-         # (Batch_Size, Out_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
-         merged = F.silu(merged)
-
-         # (Batch_Size, Out_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
-         merged = self.conv_merged(merged)
-
-         # (Batch_Size, Out_Channels, Height, Width) + (Batch_Size, Out_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
-         return merged + self.residual_layer(residue)
-
- class UNET_AttentionBlock(nn.Module):
-     def __init__(self, n_head: int, n_embd: int, d_context=768):
-         super().__init__()
-         channels = n_head * n_embd
-
-         self.groupnorm = nn.GroupNorm(32, channels, eps=1e-6)
-         self.conv_input = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
-
-         self.layernorm_1 = nn.LayerNorm(channels)
-         self.attention_1 = SelfAttention(n_head, channels, in_proj_bias=False)
-         self.layernorm_2 = nn.LayerNorm(channels)
-         self.attention_2 = CrossAttention(n_head, channels, d_context, in_proj_bias=False)
-         self.layernorm_3 = nn.LayerNorm(channels)
-         self.linear_geglu_1 = nn.Linear(channels, 4 * channels * 2)
-         self.linear_geglu_2 = nn.Linear(4 * channels, channels)
-
-         self.conv_output = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
-
-     def forward(self, x, context):
-         # x: (Batch_Size, Features, Height, Width)
-         # context: (Batch_Size, Seq_Len, Dim)
-
-         residue_long = x
-
-         # (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height, Width)
-         x = self.groupnorm(x)
-
-         # (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height, Width)
-         x = self.conv_input(x)
-
-         n, c, h, w = x.shape
-
-         # (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height * Width)
-         x = x.view((n, c, h * w))
-
-         # (Batch_Size, Features, Height * Width) -> (Batch_Size, Height * Width, Features)
-         x = x.transpose(-1, -2)
-
-         # Normalization + Self-Attention with skip connection
-
-         # (Batch_Size, Height * Width, Features)
-         residue_short = x
-
-         # (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
-         x = self.layernorm_1(x)
-
-         # (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
-         x = self.attention_1(x)
-
-         # (Batch_Size, Height * Width, Features) + (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
-         x += residue_short
-
-         # (Batch_Size, Height * Width, Features)
-         residue_short = x
-
-         # Normalization + Cross-Attention with skip connection
-
-         # (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
-         x = self.layernorm_2(x)
-
-         # (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
-         x = self.attention_2(x, context)
-
-         # (Batch_Size, Height * Width, Features) + (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
-         x += residue_short
-
-         # (Batch_Size, Height * Width, Features)
-         residue_short = x
-
-         # Normalization + FFN with GeGLU and skip connection
-
-         # (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
-         x = self.layernorm_3(x)
-
-         # GeGLU as implemented in the original code: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/attention.py#L37C10-L37C10
-         # (Batch_Size, Height * Width, Features) -> two tensors of shape (Batch_Size, Height * Width, Features * 4)
-         x, gate = self.linear_geglu_1(x).chunk(2, dim=-1)
-
-         # Element-wise product: (Batch_Size, Height * Width, Features * 4) * (Batch_Size, Height * Width, Features * 4) -> (Batch_Size, Height * Width, Features * 4)
-         x = x * F.gelu(gate)
-
-         # (Batch_Size, Height * Width, Features * 4) -> (Batch_Size, Height * Width, Features)
-         x = self.linear_geglu_2(x)
-
-         # (Batch_Size, Height * Width, Features) + (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
-         x += residue_short
-
-         # (Batch_Size, Height * Width, Features) -> (Batch_Size, Features, Height * Width)
-         x = x.transpose(-1, -2)
-
-         # (Batch_Size, Features, Height * Width) -> (Batch_Size, Features, Height, Width)
-         x = x.view((n, c, h, w))
-
-         # Final skip connection between initial input and output of the block
-         # (Batch_Size, Features, Height, Width) + (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height, Width)
-         return self.conv_output(x) + residue_long
-
- class Upsample(nn.Module):
-     def __init__(self, channels):
-         super().__init__()
-         self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
-
-     def forward(self, x):
-         # (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height * 2, Width * 2)
-         x = F.interpolate(x, scale_factor=2, mode='nearest')
-         return self.conv(x)
-
- class SwitchSequential(nn.Sequential):
-     def forward(self, x, context, time):
-         for layer in self:
-             if isinstance(layer, UNET_AttentionBlock):
-                 x = layer(x, context)
-             elif isinstance(layer, UNET_ResidualBlock):
-                 x = layer(x, time)
-             else:
-                 x = layer(x)
-         return x
-
- class UNET(nn.Module):
-     def __init__(self):
-         super().__init__()
-         self.encoders = nn.ModuleList([
-             # (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
-             SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),
-
-             # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
-             SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),
-
-             # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
-             SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),
-
-             # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 16, Width / 16)
-             SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)),
-
-             # (Batch_Size, 320, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
-             SwitchSequential(UNET_ResidualBlock(320, 640), UNET_AttentionBlock(8, 80)),
-
-             # (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
-             SwitchSequential(UNET_ResidualBlock(640, 640), UNET_AttentionBlock(8, 80)),
-
-             # (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 32, Width / 32)
-             SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)),
-
-             # (Batch_Size, 640, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
-             SwitchSequential(UNET_ResidualBlock(640, 1280), UNET_AttentionBlock(8, 160)),
-
-             # (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
-             SwitchSequential(UNET_ResidualBlock(1280, 1280), UNET_AttentionBlock(8, 160)),
-
-             # (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 64, Width / 64)
-             SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)),
-
-             # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
-             SwitchSequential(UNET_ResidualBlock(1280, 1280)),
-
-             # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
-             SwitchSequential(UNET_ResidualBlock(1280, 1280)),
-         ])
-
-         self.bottleneck = SwitchSequential(
-             # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
-             UNET_ResidualBlock(1280, 1280),
-
-             # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
-             UNET_AttentionBlock(8, 160),
-
-             # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
-             UNET_ResidualBlock(1280, 1280),
-         )
-
-         self.decoders = nn.ModuleList([
-             # (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
-             SwitchSequential(UNET_ResidualBlock(2560, 1280)),
-
-             # (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
-             SwitchSequential(UNET_ResidualBlock(2560, 1280)),
-
-             # (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 32, Width / 32)
-             SwitchSequential(UNET_ResidualBlock(2560, 1280), Upsample(1280)),
-
-             # (Batch_Size, 2560, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
-             SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),
-
-             # (Batch_Size, 2560, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
-             SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),
-
-             # (Batch_Size, 1920, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 16, Width / 16)
-             SwitchSequential(UNET_ResidualBlock(1920, 1280), UNET_AttentionBlock(8, 160), Upsample(1280)),
-
-             # (Batch_Size, 1920, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
-             SwitchSequential(UNET_ResidualBlock(1920, 640), UNET_AttentionBlock(8, 80)),
-
-             # (Batch_Size, 1280, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
-             SwitchSequential(UNET_ResidualBlock(1280, 640), UNET_AttentionBlock(8, 80)),
-
-             # (Batch_Size, 960, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 8, Width / 8)
-             SwitchSequential(UNET_ResidualBlock(960, 640), UNET_AttentionBlock(8, 80), Upsample(640)),
-
-             # (Batch_Size, 960, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
-             SwitchSequential(UNET_ResidualBlock(960, 320), UNET_AttentionBlock(8, 40)),
-
-             # (Batch_Size, 640, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
-             SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),
-
-             # (Batch_Size, 640, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
-             SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),
-         ])
-
-     def forward(self, x, context, time):
-         # x: (Batch_Size, 4, Height / 8, Width / 8)
-         # context: (Batch_Size, Seq_Len, Dim)
-         # time: (1, 1280)
-
-         skip_connections = []
-         for layers in self.encoders:
-             x = layers(x, context, time)
-             skip_connections.append(x)
-
-         x = self.bottleneck(x, context, time)
-
-         for layers in self.decoders:
-             # Since we always concat with the skip connection of the encoder, the number of features increases before being sent to the decoder's layer
-             x = torch.cat((x, skip_connections.pop()), dim=1)
-             x = layers(x, context, time)
-
-         return x
-
-
- class UNET_OutputLayer(nn.Module):
-     def __init__(self, in_channels, out_channels):
-         super().__init__()
-         self.groupnorm = nn.GroupNorm(32, in_channels)
-         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
-
-     def forward(self, x):
-         # x: (Batch_Size, 320, Height / 8, Width / 8)
-
-         # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
-         x = self.groupnorm(x)
-
-         # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
-         x = F.silu(x)
-
-         # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 4, Height / 8, Width / 8)
-         x = self.conv(x)
-
-         # (Batch_Size, 4, Height / 8, Width / 8)
-         return x
-
- class Diffusion(nn.Module):
-     def __init__(self):
-         super().__init__()
-         self.time_embedding = TimeEmbedding(320)
-         self.unet = UNET()
-         self.final = UNET_OutputLayer(320, 4)
-
-     def forward(self, latent, context, time):
-         # latent: (Batch_Size, 4, Height / 8, Width / 8)
-         # context: (Batch_Size, Seq_Len, Dim)
-         # time: (1, 320)
-
-         # (1, 320) -> (1, 1280)
-         time = self.time_embedding(time)
-
-         # (Batch, 4, Height / 8, Width / 8) -> (Batch, 320, Height / 8, Width / 8)
-         output = self.unet(latent, context, time)
-
-         # (Batch, 320, Height / 8, Width / 8) -> (Batch, 4, Height / 8, Width / 8)
-         output = self.final(output)
-
-         # (Batch, 4, Height / 8, Width / 8)
-         return output
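
A shape-only sketch of one U-Net forward pass: noisy latents plus a CLIP context and a 320-dim time embedding produce a predicted-noise tensor with the same shape as the latents. All tensors below are random placeholders, not values from the repository:

import torch
from diffusion import Diffusion  # assumes the deleted module is still importable

model = Diffusion()
latent = torch.randn(1, 4, 64, 64)   # (Batch_Size, 4, Height / 8, Width / 8)
context = torch.randn(1, 77, 768)    # stand-in for the CLIP encoder output
time = torch.randn(1, 320)           # stand-in for get_time_embedding(timestep)
with torch.no_grad():
    predicted_noise = model(latent, context, time)
print(predicted_noise.shape)  # torch.Size([1, 4, 64, 64])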
 
encoder.py DELETED
@@ -1,103 +0,0 @@
- import torch
- from torch import nn
- from torch.nn import functional as F
- from decoder import VAE_AttentionBlock, VAE_ResidualBlock
-
- class VAE_Encoder(nn.Sequential):
-     def __init__(self):
-         super().__init__(
-             # (Batch_Size, Channel, Height, Width) -> (Batch_Size, 128, Height, Width)
-             nn.Conv2d(3, 128, kernel_size=3, padding=1),
-
-             # (Batch_Size, 128, Height, Width) -> (Batch_Size, 128, Height, Width)
-             VAE_ResidualBlock(128, 128),
-
-             # (Batch_Size, 128, Height, Width) -> (Batch_Size, 128, Height, Width)
-             VAE_ResidualBlock(128, 128),
-
-             # (Batch_Size, 128, Height, Width) -> (Batch_Size, 128, Height / 2, Width / 2)
-             nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0),
-
-             # (Batch_Size, 128, Height / 2, Width / 2) -> (Batch_Size, 256, Height / 2, Width / 2)
-             VAE_ResidualBlock(128, 256),
-
-             # (Batch_Size, 256, Height / 2, Width / 2) -> (Batch_Size, 256, Height / 2, Width / 2)
-             VAE_ResidualBlock(256, 256),
-
-             # (Batch_Size, 256, Height / 2, Width / 2) -> (Batch_Size, 256, Height / 4, Width / 4)
-             nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0),
-
-             # (Batch_Size, 256, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 4, Width / 4)
-             VAE_ResidualBlock(256, 512),
-
-             # (Batch_Size, 512, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 4, Width / 4)
-             VAE_ResidualBlock(512, 512),
-
-             # (Batch_Size, 512, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 8, Width / 8)
-             nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=0),
-
-             # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
-             VAE_ResidualBlock(512, 512),
-
-             # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
-             VAE_ResidualBlock(512, 512),
-
-             # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
-             VAE_ResidualBlock(512, 512),
-
-             # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
-             VAE_AttentionBlock(512),
-
-             # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
-             VAE_ResidualBlock(512, 512),
-
-             # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
-             nn.GroupNorm(32, 512),
-
-             # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
-             nn.SiLU(),
-
-             # Because the padding=1, it means the width and height will increase by 2
-             # Out_Height = In_Height + Padding_Top + Padding_Bottom
-             # Out_Width = In_Width + Padding_Left + Padding_Right
-             # Since padding = 1 means Padding_Top = Padding_Bottom = Padding_Left = Padding_Right = 1,
-             # Since the Out_Width = In_Width + 2 (same for Out_Height), it will compensate for the Kernel size of 3
-             # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 8, Height / 8, Width / 8).
-             nn.Conv2d(512, 8, kernel_size=3, padding=1),
-
-             # (Batch_Size, 8, Height / 8, Width / 8) -> (Batch_Size, 8, Height / 8, Width / 8)
-             nn.Conv2d(8, 8, kernel_size=1, padding=0),
-         )
-
-     def forward(self, x, noise):
-         # x: (Batch_Size, Channel, Height, Width)
-         # noise: (Batch_Size, 4, Height / 8, Width / 8)
-
-         for module in self:
-
-             if getattr(module, 'stride', None) == (2, 2):  # Padding at downsampling should be asymmetric (see #8)
-                 # Pad: (Padding_Left, Padding_Right, Padding_Top, Padding_Bottom).
-                 # Pad with zeros on the right and bottom.
-                 # (Batch_Size, Channel, Height, Width) -> (Batch_Size, Channel, Height + Padding_Top + Padding_Bottom, Width + Padding_Left + Padding_Right) = (Batch_Size, Channel, Height + 1, Width + 1)
-                 x = F.pad(x, (0, 1, 0, 1))
-
-             x = module(x)
-         # (Batch_Size, 8, Height / 8, Width / 8) -> two tensors of shape (Batch_Size, 4, Height / 8, Width / 8)
-         mean, log_variance = torch.chunk(x, 2, dim=1)
-         # Clamp the log variance between -30 and 20, so that the variance is between (circa) 1e-14 and 1e8.
-         # (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 4, Height / 8, Width / 8)
-         log_variance = torch.clamp(log_variance, -30, 20)
-         # (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 4, Height / 8, Width / 8)
-         variance = log_variance.exp()
-         # (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 4, Height / 8, Width / 8)
-         stdev = variance.sqrt()
-
-         # Transform N(0, 1) -> N(mean, stdev)
-         # (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 4, Height / 8, Width / 8)
-         x = mean + stdev * noise
-
-         # Scale by a constant
-         # Constant taken from: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/configs/stable-diffusion/v1-inference.yaml#L17C1-L17C1
-         x *= 0.18215
-
-         return x
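
A companion sketch to the decoder example above: encoding a (Batch, 3, Height, Width) image together with externally supplied Gaussian noise yields a (Batch, 4, Height / 8, Width / 8) latent. The 512x512 size is again only an assumption mirroring pipeline.py:

import torch
from encoder import VAE_Encoder  # assumes the deleted module is still importable

encoder = VAE_Encoder()
image = torch.randn(1, 3, 512, 512)  # stand-in for a preprocessed image scaled to [-1, 1]
noise = torch.randn(1, 4, 64, 64)    # noise for the reparameterization step
with torch.no_grad():
    latents = encoder(image, noise)
print(latents.shape)  # torch.Size([1, 4, 64, 64])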
 
model_converter.py DELETED
The diff for this file is too large to render. See raw diff
 
model_loader.py DELETED
@@ -1,28 +0,0 @@
- from clip import CLIP
- from encoder import VAE_Encoder
- from decoder import VAE_Decoder
- from diffusion import Diffusion
-
- import model_converter
-
- def preload_models_from_standard_weights(ckpt_path, device):
-     state_dict = model_converter.load_from_standard_weights(ckpt_path, device)
-
-     encoder = VAE_Encoder().to(device)
-     encoder.load_state_dict(state_dict['encoder'], strict=True)
-
-     decoder = VAE_Decoder().to(device)
-     decoder.load_state_dict(state_dict['decoder'], strict=True)
-
-     diffusion = Diffusion().to(device)
-     diffusion.load_state_dict(state_dict['diffusion'], strict=True)
-
-     clip = CLIP().to(device)
-     clip.load_state_dict(state_dict['clip'], strict=True)
-
-     return {
-         'clip': clip,
-         'encoder': encoder,
-         'decoder': decoder,
-         'diffusion': diffusion,
-     }
 
pipeline.py DELETED
@@ -1,170 +0,0 @@
- import torch
- import numpy as np
- from tqdm import tqdm
- from ddpm import DDPMSampler
-
- WIDTH = 512
- HEIGHT = 512
- LATENTS_WIDTH = WIDTH // 8
- LATENTS_HEIGHT = HEIGHT // 8
-
- def generate(
-     prompt,
-     uncond_prompt=None,
-     input_image=None,
-     strength=0.8,
-     do_cfg=True,
-     cfg_scale=7.5,
-     sampler_name="ddpm",
-     n_inference_steps=50,
-     models={},
-     seed=None,
-     device=None,
-     idle_device=None,
-     tokenizer=None,
- ):
-     with torch.no_grad():
-         if not 0 < strength <= 1:
-             raise ValueError("strength must be between 0 and 1")
-
-         if idle_device:
-             to_idle = lambda x: x.to(idle_device)
-         else:
-             to_idle = lambda x: x
-
-         # Initialize random number generator according to the seed specified
-         generator = torch.Generator(device=device)
-         if seed is None:
-             generator.seed()
-         else:
-             generator.manual_seed(seed)
-
-         clip = models["clip"]
-         clip.to(device)
-
-         if do_cfg:
-             # Convert into a list of length Seq_Len=77
-             cond_tokens = tokenizer.batch_encode_plus(
-                 [prompt], padding="max_length", max_length=77
-             ).input_ids
-             # (Batch_Size, Seq_Len)
-             cond_tokens = torch.tensor(cond_tokens, dtype=torch.long, device=device)
-             # (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
-             cond_context = clip(cond_tokens)
-             # Convert into a list of length Seq_Len=77
-             uncond_tokens = tokenizer.batch_encode_plus(
-                 [uncond_prompt], padding="max_length", max_length=77
-             ).input_ids
-             # (Batch_Size, Seq_Len)
-             uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device)
-             # (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
-             uncond_context = clip(uncond_tokens)
-             # (Batch_Size, Seq_Len, Dim) + (Batch_Size, Seq_Len, Dim) -> (2 * Batch_Size, Seq_Len, Dim)
-             context = torch.cat([cond_context, uncond_context])
-         else:
-             # Convert into a list of length Seq_Len=77
-             tokens = tokenizer.batch_encode_plus(
-                 [prompt], padding="max_length", max_length=77
-             ).input_ids
-             # (Batch_Size, Seq_Len)
-             tokens = torch.tensor(tokens, dtype=torch.long, device=device)
-             # (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
-             context = clip(tokens)
-         to_idle(clip)
-
-         if sampler_name == "ddpm":
-             sampler = DDPMSampler(generator)
-             sampler.set_inference_timesteps(n_inference_steps)
-         else:
-             raise ValueError(f"Unknown sampler value {sampler_name}.")
-
-         latents_shape = (1, 4, LATENTS_HEIGHT, LATENTS_WIDTH)
-
-         if input_image:
-             encoder = models["encoder"]
-             encoder.to(device)
-
-             input_image_tensor = input_image.resize((WIDTH, HEIGHT))
-             # (Height, Width, Channel)
-             input_image_tensor = np.array(input_image_tensor)
-             # (Height, Width, Channel) -> (Height, Width, Channel)
-             input_image_tensor = torch.tensor(input_image_tensor, dtype=torch.float32, device=device)
-             # Rescale pixel values from [0, 255] to [-1, 1]
-             input_image_tensor = rescale(input_image_tensor, (0, 255), (-1, 1))
-             # (Height, Width, Channel) -> (Batch_Size, Height, Width, Channel)
-             input_image_tensor = input_image_tensor.unsqueeze(0)
-             # (Batch_Size, Height, Width, Channel) -> (Batch_Size, Channel, Height, Width)
-             input_image_tensor = input_image_tensor.permute(0, 3, 1, 2)
-
-             # (Batch_Size, 4, Latents_Height, Latents_Width)
-             encoder_noise = torch.randn(latents_shape, generator=generator, device=device)
-             # (Batch_Size, 4, Latents_Height, Latents_Width)
-             latents = encoder(input_image_tensor, encoder_noise)
-
-             # Add noise to the latents (the encoded input image)
-             # (Batch_Size, 4, Latents_Height, Latents_Width)
-             sampler.set_strength(strength=strength)
-             latents = sampler.add_noise(latents, sampler.timesteps[0])
-
-             to_idle(encoder)
-         else:
-             # (Batch_Size, 4, Latents_Height, Latents_Width)
-             latents = torch.randn(latents_shape, generator=generator, device=device)
-
-         diffusion = models["diffusion"]
-         diffusion.to(device)
-
-         timesteps = tqdm(sampler.timesteps)
-         for i, timestep in enumerate(timesteps):
-             # (1, 320)
-             time_embedding = get_time_embedding(timestep).to(device)
-
-             # (Batch_Size, 4, Latents_Height, Latents_Width)
-             model_input = latents
-
-             if do_cfg:
-                 # (Batch_Size, 4, Latents_Height, Latents_Width) -> (2 * Batch_Size, 4, Latents_Height, Latents_Width)
-                 model_input = model_input.repeat(2, 1, 1, 1)
-
-             # model_output is the predicted noise
-             # (Batch_Size, 4, Latents_Height, Latents_Width) -> (Batch_Size, 4, Latents_Height, Latents_Width)
-             model_output = diffusion(model_input, context, time_embedding)
-
-             if do_cfg:
-                 output_cond, output_uncond = model_output.chunk(2)
-                 model_output = cfg_scale * (output_cond - output_uncond) + output_uncond
-
-             # (Batch_Size, 4, Latents_Height, Latents_Width) -> (Batch_Size, 4, Latents_Height, Latents_Width)
-             latents = sampler.step(timestep, latents, model_output)
-
-         to_idle(diffusion)
-
-         decoder = models["decoder"]
-         decoder.to(device)
-         # (Batch_Size, 4, Latents_Height, Latents_Width) -> (Batch_Size, 3, Height, Width)
-         images = decoder(latents)
-         to_idle(decoder)
-
-         images = rescale(images, (-1, 1), (0, 255), clamp=True)
-         # (Batch_Size, Channel, Height, Width) -> (Batch_Size, Height, Width, Channel)
-         images = images.permute(0, 2, 3, 1)
-         images = images.to("cpu", torch.uint8).numpy()
-         return images[0]
-
- def rescale(x, old_range, new_range, clamp=False):
-     old_min, old_max = old_range
-     new_min, new_max = new_range
-     x -= old_min
-     x *= (new_max - new_min) / (old_max - old_min)
-     x += new_min
-     if clamp:
-         x = x.clamp(new_min, new_max)
-     return x
-
- def get_time_embedding(timestep):
-     # Shape: (160,)
-     freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160)
-     # Shape: (1, 160)
-     x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
-     # Shape: (1, 160 * 2)
-     return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
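
To make the time encoding at the end of the file concrete, a tiny sketch: each scalar timestep becomes a (1, 320) vector of cosines and sines over 160 log-spaced frequencies, which TimeEmbedding in diffusion.py then projects to 1280 dimensions. The timestep value below is arbitrary:

import torch
from pipeline import get_time_embedding  # assumes the deleted module is still importable

emb = get_time_embedding(999)
print(emb.shape)  # torch.Size([1, 320])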