Hunyuan3D-2 (Image-to-3D, Diffusers, Safetensors, English, Chinese, text-to-3d)
hunyuan3d-paint-v2-0-turbo/unet/diffusion_pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:690a5fc63c4e263ba07dd41dc95f86f702c059c4361b863e2e21af88d8f75714
- size 3722674238
+ oid sha256:24e7f1aea8a7c94cee627eb06f5265f19eeff4e19568636c5eaef050cc19ba3d
+ size 7325432923
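The pointer change above only swaps in a new LFS object for the UNet weights (the payload grows from 3722674238 to 7325432923 bytes). If you want to confirm that a locally downloaded copy of diffusion_pytorch_model.bin matches the updated pointer, a minimal sketch along these lines should work; the local path is an assumption and depends on where the file was fetched:

import hashlib
import os

# Assumed local path to the downloaded UNet weights; adjust to your layout.
weights_path = "hunyuan3d-paint-v2-0-turbo/unet/diffusion_pytorch_model.bin"

# Expected values copied from the updated LFS pointer above.
EXPECTED_SHA256 = "24e7f1aea8a7c94cee627eb06f5265f19eeff4e19568636c5eaef050cc19ba3d"
EXPECTED_SIZE = 7325432923

assert os.path.getsize(weights_path) == EXPECTED_SIZE, "unexpected file size"

sha = hashlib.sha256()
with open(weights_path, "rb") as f:
    for chunk in iter(lambda: f.read(8 * 1024 * 1024), b""):
        sha.update(chunk)
assert sha.hexdigest() == EXPECTED_SHA256, "sha256 mismatch"
print("diffusion_pytorch_model.bin matches the new LFS pointer")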
hunyuan3d-paint-v2-0-turbo/unet/modules.py CHANGED
@@ -1,13 +1,3 @@
1
- # Open Source Model Licensed under the Apache License Version 2.0
2
- # and Other Licenses of the Third-Party Components therein:
3
- # The below Model in this distribution may have been modified by THL A29 Limited
4
- # ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
5
-
6
- # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
7
- # The below software and/or models in this distribution may have been
8
- # modified by THL A29 Limited ("Tencent Modifications").
9
- # All Tencent Modifications are Copyright (C) THL A29 Limited.
10
-
11
  # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
12
  # except for the third-party components listed below.
13
  # Hunyuan 3D does not impose any additional limitations beyond what is outlined
@@ -22,7 +12,6 @@
22
  # fine-tuning enabling code and other elements of the foregoing made publicly available
23
  # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
24
 
25
-
26
  import copy
27
  import json
28
  import os
@@ -41,7 +30,9 @@ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim:
41
  # "feed_forward_chunk_size" can be used to save memory
42
  if hidden_states.shape[chunk_dim] % chunk_size != 0:
43
  raise ValueError(
44
- f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
 
 
45
  )
46
 
47
  num_chunks = hidden_states.shape[chunk_dim] // chunk_size
@@ -51,329 +42,16 @@ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim:
51
  )
52
  return ff_output
53
 
54
- class PoseRoPEAttnProcessor2_0:
55
- r"""
56
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
57
- """
58
-
59
- def __init__(self):
60
- if not hasattr(F, "scaled_dot_product_attention"):
61
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
62
-
63
- def get_1d_rotary_pos_embed(
64
- self,
65
- dim: int,
66
- pos: torch.Tensor,
67
- theta: float = 10000.0,
68
- linear_factor=1.0,
69
- ntk_factor=1.0,
70
- ):
71
- assert dim % 2 == 0
72
-
73
- theta = theta * ntk_factor
74
- freqs = (
75
- 1.0
76
- / (theta ** (torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device)[: (dim // 2)] / dim))
77
- / linear_factor
78
- ) # [D/2]
79
- freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
80
- # flux, hunyuan-dit, cogvideox
81
- freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
82
- freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
83
- return freqs_cos, freqs_sin
84
-
85
-
86
- def get_3d_rotary_pos_embed(
87
- self,
88
- position,
89
- embed_dim,
90
- voxel_resolution,
91
- theta: int = 10000,
92
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
93
- """
94
- RoPE for video tokens with 3D structure.
95
-
96
- Args:
97
- voxel_resolution (`int`):
98
- The grid size of the spatial positional embedding (height, width).
99
- theta (`float`):
100
- Scaling factor for frequency computation.
101
-
102
- Returns:
103
- `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
104
- """
105
- assert position.shape[-1]==3
106
-
107
- # Compute dimensions for each axis
108
- dim_xy = embed_dim // 8 * 3
109
- dim_z = embed_dim // 8 * 2
110
-
111
- # Temporal frequencies
112
- grid = torch.arange(voxel_resolution, dtype=torch.float32, device=position.device)
113
- freqs_xy = self.get_1d_rotary_pos_embed(dim_xy, grid, theta=theta)
114
- freqs_z = self.get_1d_rotary_pos_embed(dim_z, grid, theta=theta)
115
-
116
- xy_cos, xy_sin = freqs_xy # both t_cos and t_sin has shape: voxel_resolution, dim_xy
117
- z_cos, z_sin = freqs_z # both w_cos and w_sin has shape: voxel_resolution, dim_z
118
-
119
- embed_flattn = position.view(-1, position.shape[-1])
120
- x_cos = xy_cos[embed_flattn[:,0], :]
121
- x_sin = xy_sin[embed_flattn[:,0], :]
122
- y_cos = xy_cos[embed_flattn[:,1], :]
123
- y_sin = xy_sin[embed_flattn[:,1], :]
124
- z_cos = z_cos[embed_flattn[:,2], :]
125
- z_sin = z_sin[embed_flattn[:,2], :]
126
-
127
- cos = torch.cat((x_cos, y_cos, z_cos), dim=-1)
128
- sin = torch.cat((x_sin, y_sin, z_sin), dim=-1)
129
-
130
- cos = cos.view(*position.shape[:-1], embed_dim)
131
- sin = sin.view(*position.shape[:-1], embed_dim)
132
- return cos, sin
133
-
134
- def apply_rotary_emb(
135
- self,
136
- x: torch.Tensor,
137
- freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]
138
- ):
139
- cos, sin = freqs_cis # [S, D]
140
- cos, sin = cos.to(x.device), sin.to(x.device)
141
- cos = cos.unsqueeze(1)
142
- sin = sin.unsqueeze(1)
143
-
144
- x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
145
- x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
146
-
147
- out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
148
-
149
- return out
150
-
151
- def __call__(
152
- self,
153
- attn: Attention,
154
- hidden_states: torch.Tensor,
155
- encoder_hidden_states: Optional[torch.Tensor] = None,
156
- attention_mask: Optional[torch.Tensor] = None,
157
- position_indices: Dict = None,
158
- temb: Optional[torch.Tensor] = None,
159
- *args,
160
- **kwargs,
161
- ) -> torch.Tensor:
162
- if len(args) > 0 or kwargs.get("scale", None) is not None:
163
- deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
164
- deprecate("scale", "1.0.0", deprecation_message)
165
-
166
- residual = hidden_states
167
- if attn.spatial_norm is not None:
168
- hidden_states = attn.spatial_norm(hidden_states, temb)
169
-
170
- input_ndim = hidden_states.ndim
171
-
172
- if input_ndim == 4:
173
- batch_size, channel, height, width = hidden_states.shape
174
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
175
-
176
- batch_size, sequence_length, _ = (
177
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
178
- )
179
-
180
- if attention_mask is not None:
181
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
182
- # scaled_dot_product_attention expects attention_mask shape to be
183
- # (batch, heads, source_length, target_length)
184
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
185
-
186
- if attn.group_norm is not None:
187
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
188
-
189
- query = attn.to_q(hidden_states)
190
-
191
- if encoder_hidden_states is None:
192
- encoder_hidden_states = hidden_states
193
- elif attn.norm_cross:
194
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
195
-
196
- key = attn.to_k(encoder_hidden_states)
197
- value = attn.to_v(encoder_hidden_states)
198
-
199
- inner_dim = key.shape[-1]
200
- head_dim = inner_dim // attn.heads
201
-
202
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
203
-
204
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
205
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
206
-
207
- if attn.norm_q is not None:
208
- query = attn.norm_q(query)
209
- if attn.norm_k is not None:
210
- key = attn.norm_k(key)
211
-
212
- if position_indices is not None:
213
- if head_dim in position_indices:
214
- image_rotary_emb = position_indices[head_dim]
215
- else:
216
- image_rotary_emb = self.get_3d_rotary_pos_embed(position_indices['voxel_indices'], head_dim, voxel_resolution=position_indices['voxel_resolution'])
217
- position_indices[head_dim] = image_rotary_emb
218
- query = self.apply_rotary_emb(query, image_rotary_emb)
219
- key = self.apply_rotary_emb(key, image_rotary_emb)
220
-
221
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
222
- # TODO: add support for attn.scale when we move to Torch 2.1
223
- hidden_states = F.scaled_dot_product_attention(
224
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
225
- )
226
-
227
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
228
- hidden_states = hidden_states.to(query.dtype)
229
-
230
- # linear proj
231
- hidden_states = attn.to_out[0](hidden_states)
232
- # dropout
233
- hidden_states = attn.to_out[1](hidden_states)
234
-
235
- if input_ndim == 4:
236
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
237
-
238
- if attn.residual_connection:
239
- hidden_states = hidden_states + residual
240
-
241
- hidden_states = hidden_states / attn.rescale_output_factor
242
-
243
- return hidden_states
244
-
245
- class IPAttnProcessor2_0:
246
- r"""
247
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
248
- """
249
-
250
- def __init__(self, scale=0.0):
251
- if not hasattr(F, "scaled_dot_product_attention"):
252
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
253
-
254
- self.scale = scale
255
-
256
- def __call__(
257
- self,
258
- attn: Attention,
259
- hidden_states: torch.Tensor,
260
- encoder_hidden_states: Optional[torch.Tensor] = None,
261
- ip_hidden_states: Optional[torch.Tensor] = None,
262
- attention_mask: Optional[torch.Tensor] = None,
263
- temb: Optional[torch.Tensor] = None,
264
- *args,
265
- **kwargs,
266
- ) -> torch.Tensor:
267
- if len(args) > 0 or kwargs.get("scale", None) is not None:
268
- deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
269
- deprecate("scale", "1.0.0", deprecation_message)
270
-
271
- residual = hidden_states
272
- if attn.spatial_norm is not None:
273
- hidden_states = attn.spatial_norm(hidden_states, temb)
274
-
275
- input_ndim = hidden_states.ndim
276
-
277
- if input_ndim == 4:
278
- batch_size, channel, height, width = hidden_states.shape
279
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
280
-
281
- batch_size, sequence_length, _ = (
282
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
283
- )
284
-
285
- if attention_mask is not None:
286
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
287
- # scaled_dot_product_attention expects attention_mask shape to be
288
- # (batch, heads, source_length, target_length)
289
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
290
-
291
- if attn.group_norm is not None:
292
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
293
-
294
- query = attn.to_q(hidden_states)
295
-
296
- if encoder_hidden_states is None:
297
- encoder_hidden_states = hidden_states
298
- elif attn.norm_cross:
299
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
300
-
301
- key = attn.to_k(encoder_hidden_states)
302
- value = attn.to_v(encoder_hidden_states)
303
-
304
- inner_dim = key.shape[-1]
305
- head_dim = inner_dim // attn.heads
306
-
307
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
308
-
309
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
310
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
311
-
312
- if attn.norm_q is not None:
313
- query = attn.norm_q(query)
314
- if attn.norm_k is not None:
315
- key = attn.norm_k(key)
316
-
317
-
318
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
319
- # TODO: add support for attn.scale when we move to Torch 2.1
320
- hidden_states = F.scaled_dot_product_attention(
321
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
322
- )
323
-
324
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
325
- hidden_states = hidden_states.to(query.dtype)
326
-
327
- # for ip adapter
328
- if ip_hidden_states is not None:
329
-
330
- ip_key = attn.to_k_ip(ip_hidden_states)
331
- ip_value = attn.to_v_ip(ip_hidden_states)
332
-
333
- ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
334
- ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
335
-
336
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
337
- ip_hidden_states = F.scaled_dot_product_attention(
338
- query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
339
- )
340
-
341
- ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
342
- ip_hidden_states = ip_hidden_states.to(query.dtype)
343
-
344
- hidden_states = hidden_states + self.scale * ip_hidden_states
345
-
346
- # linear proj
347
- hidden_states = attn.to_out[0](hidden_states)
348
- # dropout
349
- hidden_states = attn.to_out[1](hidden_states)
350
-
351
- if input_ndim == 4:
352
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
353
-
354
- if attn.residual_connection:
355
- hidden_states = hidden_states + residual
356
-
357
- hidden_states = hidden_states / attn.rescale_output_factor
358
-
359
- return hidden_states
360
-
361
 
362
  class Basic2p5DTransformerBlock(torch.nn.Module):
363
- def __init__(self, transformer: BasicTransformerBlock, layer_name, use_ipa=True, use_ma=True, use_ra=True) -> None:
364
  super().__init__()
365
  self.transformer = transformer
366
  self.layer_name = layer_name
367
- self.use_ipa = use_ipa
368
  self.use_ma = use_ma
369
  self.use_ra = use_ra
 
370
 
371
- if use_ipa:
372
- self.attn2.set_processor(IPAttnProcessor2_0())
373
- cross_attention_dim = 1024
374
- self.attn2.to_k_ip = nn.Linear(cross_attention_dim, self.dim, bias=False)
375
- self.attn2.to_v_ip = nn.Linear(cross_attention_dim, self.dim, bias=False)
376
-
377
  # multiview attn
378
  if self.use_ma:
379
  self.attn_multiview = Attention(
@@ -385,7 +63,6 @@ class Basic2p5DTransformerBlock(torch.nn.Module):
385
  cross_attention_dim=None,
386
  upcast_attention=self.attn1.upcast_attention,
387
  out_bias=True,
388
- processor=PoseRoPEAttnProcessor2_0(),
389
  )
390
 
391
  # ref attn
@@ -400,8 +77,8 @@ class Basic2p5DTransformerBlock(torch.nn.Module):
400
  upcast_attention=self.attn1.upcast_attention,
401
  out_bias=True,
402
  )
403
-
404
- self._initialize_attn_weights()
405
 
406
  def _initialize_attn_weights(self):
407
 
@@ -418,10 +95,6 @@ class Basic2p5DTransformerBlock(torch.nn.Module):
418
  for param in layer.parameters():
419
  param.zero_()
420
 
421
- if self.use_ipa:
422
- self.attn2.to_k_ip.load_state_dict(self.attn2.to_k.state_dict())
423
- self.attn2.to_v_ip.load_state_dict(self.attn2.to_v.state_dict())
424
-
425
  def __getattr__(self, name: str):
426
  try:
427
  return super().__getattr__(name)
@@ -447,10 +120,16 @@ class Basic2p5DTransformerBlock(torch.nn.Module):
447
  cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
448
  num_in_batch = cross_attention_kwargs.pop('num_in_batch', 1)
449
  mode = cross_attention_kwargs.pop('mode', None)
450
  condition_embed_dict = cross_attention_kwargs.pop("condition_embed_dict", None)
451
- ip_hidden_states = cross_attention_kwargs.pop("ip_hidden_states", None)
452
- position_attn_mask = cross_attention_kwargs.pop("position_attn_mask", None)
453
- position_voxel_indices = cross_attention_kwargs.pop("position_voxel_indices", None)
454
 
455
  if self.norm_type == "ada_norm":
456
  norm_hidden_states = self.norm1(hidden_states, timestep)
@@ -470,10 +149,10 @@ class Basic2p5DTransformerBlock(torch.nn.Module):
470
  norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
471
  else:
472
  raise ValueError("Incorrect norm used")
473
-
474
  if self.pos_embed is not None:
475
  norm_hidden_states = self.pos_embed(norm_hidden_states)
476
-
477
  # 1. Prepare GLIGEN inputs
478
  cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
479
  gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
@@ -484,6 +163,7 @@ class Basic2p5DTransformerBlock(torch.nn.Module):
484
  attention_mask=attention_mask,
485
  **cross_attention_kwargs,
486
  )
 
487
  if self.norm_type == "ada_norm_zero":
488
  attn_output = gate_msa.unsqueeze(1) * attn_output
489
  elif self.norm_type == "ada_norm_single":
@@ -492,13 +172,17 @@ class Basic2p5DTransformerBlock(torch.nn.Module):
492
  hidden_states = attn_output + hidden_states
493
  if hidden_states.ndim == 4:
494
  hidden_states = hidden_states.squeeze(1)
495
-
496
  # 1.2 Reference Attention
497
  if 'w' in mode:
498
- condition_embed_dict[self.layer_name] = rearrange(norm_hidden_states, '(b n) l c -> b (n l) c', n=num_in_batch) # B, (N L), C
499
-
500
- if 'r' in mode:
501
- condition_embed = condition_embed_dict[self.layer_name].unsqueeze(1).repeat(1,num_in_batch,1,1) # B N L C
502
  condition_embed = rearrange(condition_embed, 'b n l c -> (b n) l c')
503
 
504
  attn_output = self.attn_refview(
@@ -507,35 +191,48 @@ class Basic2p5DTransformerBlock(torch.nn.Module):
507
  attention_mask=None,
508
  **cross_attention_kwargs
509
  )
510
 
511
- hidden_states = attn_output + hidden_states
512
  if hidden_states.ndim == 4:
513
  hidden_states = hidden_states.squeeze(1)
514
-
515
 
516
  # 1.3 Multiview Attention
517
  if num_in_batch > 1 and self.use_ma:
518
  multivew_hidden_states = rearrange(norm_hidden_states, '(b n) l c -> b (n l) c', n=num_in_batch)
519
- position_mask = None
520
- if position_attn_mask is not None:
521
- if multivew_hidden_states.shape[1] in position_attn_mask:
522
- position_mask = position_attn_mask[multivew_hidden_states.shape[1]]
523
- position_indices = None
524
- if position_voxel_indices is not None:
525
- if multivew_hidden_states.shape[1] in position_voxel_indices:
526
- position_indices = position_voxel_indices[multivew_hidden_states.shape[1]]
527
-
528
- attn_output = self.attn_multiview(
529
- multivew_hidden_states,
530
- encoder_hidden_states=multivew_hidden_states,
531
- attention_mask=position_mask,
532
- position_indices=position_indices,
533
- **cross_attention_kwargs
534
- )
535
 
536
- attn_output = rearrange(attn_output, 'b (n l) c -> (b n) l c', n=num_in_batch)
537
 
538
- hidden_states = attn_output + hidden_states
 
 
539
  if hidden_states.ndim == 4:
540
  hidden_states = hidden_states.squeeze(1)
541
 
@@ -561,25 +258,12 @@ class Basic2p5DTransformerBlock(torch.nn.Module):
561
  if self.pos_embed is not None and self.norm_type != "ada_norm_single":
562
  norm_hidden_states = self.pos_embed(norm_hidden_states)
563
 
564
- if ip_hidden_states is not None:
565
- ip_hidden_states = ip_hidden_states.unsqueeze(1).repeat(1,num_in_batch,1,1) # B N L C
566
- ip_hidden_states = rearrange(ip_hidden_states, 'b n l c -> (b n) l c')
567
-
568
- if self.use_ipa:
569
- attn_output = self.attn2(
570
- norm_hidden_states,
571
- encoder_hidden_states=encoder_hidden_states,
572
- ip_hidden_states=ip_hidden_states,
573
- attention_mask=encoder_attention_mask,
574
- **cross_attention_kwargs,
575
- )
576
- else:
577
- attn_output = self.attn2(
578
- norm_hidden_states,
579
- encoder_hidden_states=encoder_hidden_states,
580
- attention_mask=encoder_attention_mask,
581
- **cross_attention_kwargs,
582
- )
583
 
584
  hidden_states = attn_output + hidden_states
585
 
@@ -626,8 +310,16 @@ def compute_voxel_grid_mask(position, grid_resolution=8):
626
  position[valid_mask==False] = 0
627
 
628
 
629
- position = rearrange(position, 'b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w', num_h=grid_resolution, num_w=grid_resolution)
630
- valid_mask = rearrange(valid_mask, 'b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w', num_h=grid_resolution, num_w=grid_resolution)
631
 
632
  grid_position = position.sum(dim=(-2, -1))
633
  count_masked = valid_mask.sum(dim=(-2, -1))
@@ -674,8 +366,16 @@ def compute_discrete_voxel_indice(position, grid_resolution=8, voxel_resolution=
674
  valid_mask = valid_mask.expand_as(position)
675
  position[valid_mask==False] = 0
676
 
677
- position = rearrange(position, 'b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w', num_h=grid_resolution, num_w=grid_resolution)
678
- valid_mask = rearrange(valid_mask, 'b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w', num_h=grid_resolution, num_w=grid_resolution)
679
 
680
  grid_position = position.sum(dim=(-2, -1))
681
  count_masked = valid_mask.sum(dim=(-2, -1))
@@ -688,45 +388,36 @@ def compute_discrete_voxel_indice(position, grid_resolution=8, voxel_resolution=
688
  voxel_indices = torch.round(voxel_indices).long()
689
  return voxel_indices
690
 
691
- def compute_multi_resolution_discrete_voxel_indice(position_maps, grid_resolutions=[64, 32, 16, 8], voxel_resolutions=[512, 256, 128, 64]):
692
  voxel_indices = {}
693
  with torch.no_grad():
694
  for grid_resolution, voxel_resolution in zip(grid_resolutions, voxel_resolutions):
695
  voxel_indice = compute_discrete_voxel_indice(position_maps, grid_resolution, voxel_resolution)
696
  voxel_indice = rearrange(voxel_indice, 'b n c h w -> b (n h w) c')
697
  voxel_indices[voxel_indice.shape[1]] = {'voxel_indices':voxel_indice, 'voxel_resolution':voxel_resolution}
698
- return voxel_indices
699
-
700
- class ImageProjModel(torch.nn.Module):
701
- """Projection Model"""
702
-
703
- def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
704
- super().__init__()
705
-
706
- self.generator = None
707
- self.cross_attention_dim = cross_attention_dim
708
- self.clip_extra_context_tokens = clip_extra_context_tokens
709
- self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
710
- self.norm = torch.nn.LayerNorm(cross_attention_dim)
711
 
712
- def forward(self, image_embeds):
713
- embeds = image_embeds
714
- clip_extra_context_tokens = self.proj(embeds).reshape(
715
- -1, self.clip_extra_context_tokens, self.cross_attention_dim
716
- )
717
- clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
718
- return clip_extra_context_tokens
719
-
720
  class UNet2p5DConditionModel(torch.nn.Module):
721
  def __init__(self, unet: UNet2DConditionModel) -> None:
722
  super().__init__()
723
  self.unet = unet
724
- self.unet_dual = copy.deepcopy(unet)
725
 
726
- self.init_camera_embedding()
727
- self.init_attention(self.unet, use_ipa=True, use_ma=True, use_ra=True)
728
- self.init_attention(self.unet_dual, use_ipa=False, use_ma=False, use_ra=False)
729
  self.init_condition()
 
730
 
731
  @staticmethod
732
  def from_pretrained(pretrained_model_name_or_path, **kwargs):
@@ -737,170 +428,158 @@ class UNet2p5DConditionModel(torch.nn.Module):
737
  config = json.load(file)
738
  unet = UNet2DConditionModel(**config)
739
  unet = UNet2p5DConditionModel(unet)
740
-
741
- unet.unet.conv_in = torch.nn.Conv2d(
742
- 12,
743
- unet.unet.conv_in.out_channels,
744
- kernel_size=unet.unet.conv_in.kernel_size,
745
- stride=unet.unet.conv_in.stride,
746
- padding=unet.unet.conv_in.padding,
747
- dilation=unet.unet.conv_in.dilation,
748
- groups=unet.unet.conv_in.groups,
749
- bias=unet.unet.conv_in.bias is not None)
750
-
751
  unet_ckpt = torch.load(unet_ckpt_path, map_location='cpu', weights_only=True)
752
  unet.load_state_dict(unet_ckpt, strict=True)
753
  unet = unet.to(torch_dtype)
754
  return unet
755
-
756
- def init_condition(self):
757
- self.unet.learned_text_clip_gen = nn.Parameter(torch.randn(1,77,1024))
758
- self.unet.learned_text_clip_ref = nn.Parameter(torch.randn(1,77,1024))
759
 
760
- self.unet.image_proj_model = ImageProjModel(
761
- cross_attention_dim=self.unet.config.cross_attention_dim,
762
- clip_embeddings_dim=1024,
763
- )
764
 
 
 
765
 
766
  def init_camera_embedding(self):
767
- self.max_num_ref_image = 5
768
- self.max_num_gen_image = 12*3+4*2
769
 
770
- time_embed_dim = 1280
771
- self.unet.class_embedding = nn.Embedding(self.max_num_ref_image+self.max_num_gen_image, time_embed_dim)
772
- # Initialize the embedding layer weights to all zeros
773
- nn.init.zeros_(self.unet.class_embedding.weight)
774
-
775
- def init_attention(self, unet, use_ipa=True, use_ma=True, use_ra=True):
 
776
 
777
  for down_block_i, down_block in enumerate(unet.down_blocks):
778
  if hasattr(down_block, "has_cross_attention") and down_block.has_cross_attention:
779
  for attn_i, attn in enumerate(down_block.attentions):
780
  for transformer_i, transformer in enumerate(attn.transformer_blocks):
781
  if isinstance(transformer, BasicTransformerBlock):
782
- attn.transformer_blocks[transformer_i] = Basic2p5DTransformerBlock(transformer, f'down_{down_block_i}_{attn_i}_{transformer_i}',use_ipa,use_ma,use_ra)
783
 
784
  if hasattr(unet.mid_block, "has_cross_attention") and unet.mid_block.has_cross_attention:
785
  for attn_i, attn in enumerate(unet.mid_block.attentions):
786
  for transformer_i, transformer in enumerate(attn.transformer_blocks):
787
  if isinstance(transformer, BasicTransformerBlock):
788
- attn.transformer_blocks[transformer_i] = Basic2p5DTransformerBlock(transformer, f'mid_{attn_i}_{transformer_i}',use_ipa,use_ma,use_ra)
789
 
790
  for up_block_i, up_block in enumerate(unet.up_blocks):
791
  if hasattr(up_block, "has_cross_attention") and up_block.has_cross_attention:
792
  for attn_i, attn in enumerate(up_block.attentions):
793
  for transformer_i, transformer in enumerate(attn.transformer_blocks):
794
  if isinstance(transformer, BasicTransformerBlock):
795
- attn.transformer_blocks[transformer_i] = Basic2p5DTransformerBlock(transformer, f'up_{up_block_i}_{attn_i}_{transformer_i}',use_ipa,use_ma,use_ra)
796
-
797
 
798
  def __getattr__(self, name: str):
799
  try:
800
  return super().__getattr__(name)
801
  except AttributeError:
802
  return getattr(self.unet, name)
803
-
804
  def forward(
805
- self, sample, timestep, encoder_hidden_states, class_labels=None,
806
- *args, cross_attention_kwargs=None, down_intrablock_additional_residuals=None,
807
  down_block_res_samples=None, mid_block_res_sample=None,
808
  **cached_condition,
809
  ):
810
  B, N_gen, _, H, W = sample.shape
811
- camera_info_gen = cached_condition['camera_info_gen'] + self.max_num_ref_image
812
- camera_info_gen = rearrange(camera_info_gen, 'b n -> (b n)')
813
  sample = [sample]
814
-
815
  if 'normal_imgs' in cached_condition:
816
  sample.append(cached_condition["normal_imgs"])
817
  if 'position_imgs' in cached_condition:
818
  sample.append(cached_condition["position_imgs"])
819
-
820
  sample = torch.cat(sample, dim=2)
 
821
  sample = rearrange(sample, 'b n c h w -> (b n) c h w')
822
 
823
  encoder_hidden_states_gen = encoder_hidden_states.unsqueeze(1).repeat(1, N_gen, 1, 1)
824
  encoder_hidden_states_gen = rearrange(encoder_hidden_states_gen, 'b n l c -> (b n) l c')
825
-
826
-
827
- use_position_mask = False
828
- use_position_rope = True
829
-
830
- position_attn_mask = None
831
- if use_position_mask:
832
- if 'position_attn_mask' in cached_condition:
833
- position_attn_mask = cached_condition['position_attn_mask']
834
- else:
835
- if 'position_maps' in cached_condition:
836
- position_attn_mask = compute_multi_resolution_mask(cached_condition['position_maps'])
837
-
838
- position_voxel_indices = None
839
- if use_position_rope:
840
- if 'position_voxel_indices' in cached_condition:
841
- position_voxel_indices = cached_condition['position_voxel_indices']
842
- else:
843
- if 'position_maps' in cached_condition:
844
- position_voxel_indices = compute_multi_resolution_discrete_voxel_indice(cached_condition['position_maps'])
845
 
846
- if 'ip_hidden_states' in cached_condition:
847
- ip_hidden_states = cached_condition['ip_hidden_states']
848
- else:
849
- if 'clip_embeds' in cached_condition:
850
- ip_hidden_states = self.image_proj_model(cached_condition['clip_embeds'])
851
  else:
852
- ip_hidden_states = None
853
- cached_condition['ip_hidden_states'] = ip_hidden_states
854
-
855
- if 'condition_embed_dict' in cached_condition:
856
- condition_embed_dict = cached_condition['condition_embed_dict']
857
  else:
858
- condition_embed_dict = {}
859
- ref_latents = cached_condition['ref_latents']
860
- N_ref = ref_latents.shape[1]
861
- camera_info_ref = cached_condition['camera_info_ref']
862
- camera_info_ref = rearrange(camera_info_ref, 'b n -> (b n)')
863
-
864
- #ref_latents = [ref_latents]
865
- #if 'normal_imgs' in cached_condition:
866
- # ref_latents.append(torch.zeros_like(ref_latents[0]))
867
- #if 'position_imgs' in cached_condition:
868
- # ref_latents.append(torch.zeros_like(ref_latents[0]))
869
- #ref_latents = torch.cat(ref_latents, dim=2)
870
-
871
- ref_latents = rearrange(ref_latents, 'b n c h w -> (b n) c h w')
872
 
873
- encoder_hidden_states_ref = self.learned_text_clip_ref.unsqueeze(1).repeat(B, N_ref, 1, 1)
874
- encoder_hidden_states_ref = rearrange(encoder_hidden_states_ref, 'b n l c -> (b n) l c')
875
-
876
- noisy_ref_latents = ref_latents
877
- timestep_ref = 0
878
- '''
879
- if timestep.dim()>0:
880
- timestep_ref = rearrange(timestep, '(b n) -> b n', b=B)[:,:1].repeat(1, N_ref)
881
- timestep_ref = rearrange(timestep_ref, 'b n -> (b n)')
882
- else:
883
- timestep_ref = timestep
884
- noise = torch.randn_like(noisy_ref_latents[:,:4,...])
885
- if self.training:
886
- noisy_ref_latents[:,:4,...] = self.train_sched.add_noise(noisy_ref_latents[:,:4,...], noise, timestep_ref)
887
- noisy_ref_latents[:,:4,...] = self.train_sched.scale_model_input(noisy_ref_latents[:,:4,...], timestep_ref)
888
- else:
889
- noisy_ref_latents[:,:4,...] = self.val_sched.add_noise(noisy_ref_latents[:,:4,...], noise, timestep_ref.reshape(-1))
890
- noisy_ref_latents[:,:4,...] = self.val_sched.scale_model_input(noisy_ref_latents[:,:4,...], timestep_ref.reshape(-1))
891
- '''
892
- self.unet_dual(
893
- noisy_ref_latents, timestep_ref,
894
- encoder_hidden_states=encoder_hidden_states_ref,
895
- #class_labels=camera_info_ref,
896
- # **kwargs
897
- return_dict=False,
898
- cross_attention_kwargs={
899
- 'mode':'w', 'num_in_batch':N_ref,
900
- 'condition_embed_dict':condition_embed_dict},
901
- )
902
- cached_condition['condition_embed_dict'] = condition_embed_dict
903
 
904
  return self.unet(
905
  sample, timestep,
906
  encoder_hidden_states_gen, *args,
@@ -916,11 +595,6 @@ class UNet2p5DConditionModel(torch.nn.Module):
916
  if mid_block_res_sample is not None else None
917
  ),
918
  return_dict=False,
919
- cross_attention_kwargs={
920
- 'mode':'r', 'num_in_batch':N_gen,
921
- 'ip_hidden_states':ip_hidden_states,
922
- 'condition_embed_dict':condition_embed_dict,
923
- 'position_attn_mask':position_attn_mask,
924
- 'position_voxel_indices':position_voxel_indices
925
- },
926
1
  # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
2
  # except for the third-party components listed below.
3
  # Hunyuan 3D does not impose any additional limitations beyond what is outlined
 
12
  # fine-tuning enabling code and other elements of the foregoing made publicly available
13
  # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
14
 
 
15
  import copy
16
  import json
17
  import os
 
30
  # "feed_forward_chunk_size" can be used to save memory
31
  if hidden_states.shape[chunk_dim] % chunk_size != 0:
32
  raise ValueError(
33
+ f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]}"
34
+ f"has to be divisible by chunk size: {chunk_size}."
35
+ f" Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
36
  )
37
 
38
  num_chunks = hidden_states.shape[chunk_dim] // chunk_size
 
42
  )
43
  return ff_output
44
45
 
46
  class Basic2p5DTransformerBlock(torch.nn.Module):
47
+ def __init__(self, transformer: BasicTransformerBlock, layer_name, use_ma=True, use_ra=True, is_turbo=False) -> None:
48
  super().__init__()
49
  self.transformer = transformer
50
  self.layer_name = layer_name
 
51
  self.use_ma = use_ma
52
  self.use_ra = use_ra
53
+ self.is_turbo = is_turbo
54
 
55
  # multiview attn
56
  if self.use_ma:
57
  self.attn_multiview = Attention(
 
63
  cross_attention_dim=None,
64
  upcast_attention=self.attn1.upcast_attention,
65
  out_bias=True,
 
66
  )
67
 
68
  # ref attn
 
77
  upcast_attention=self.attn1.upcast_attention,
78
  out_bias=True,
79
  )
80
+ if self.is_turbo:
81
+ self._initialize_attn_weights()
82
 
83
  def _initialize_attn_weights(self):
84
 
 
95
  for param in layer.parameters():
96
  param.zero_()
97
 
 
 
 
 
98
  def __getattr__(self, name: str):
99
  try:
100
  return super().__getattr__(name)
 
120
  cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
121
  num_in_batch = cross_attention_kwargs.pop('num_in_batch', 1)
122
  mode = cross_attention_kwargs.pop('mode', None)
123
+ if not self.is_turbo:
124
+ mva_scale = cross_attention_kwargs.pop('mva_scale', 1.0)
125
+ ref_scale = cross_attention_kwargs.pop('ref_scale', 1.0)
126
+ else:
127
+ position_attn_mask = cross_attention_kwargs.pop("position_attn_mask", None)
128
+ position_voxel_indices = cross_attention_kwargs.pop("position_voxel_indices", None)
129
+ mva_scale = 1.0
130
+ ref_scale = 1.0
131
+
132
  condition_embed_dict = cross_attention_kwargs.pop("condition_embed_dict", None)
 
 
134
  if self.norm_type == "ada_norm":
135
  norm_hidden_states = self.norm1(hidden_states, timestep)
 
149
  norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
150
  else:
151
  raise ValueError("Incorrect norm used")
152
+
153
  if self.pos_embed is not None:
154
  norm_hidden_states = self.pos_embed(norm_hidden_states)
155
+
156
  # 1. Prepare GLIGEN inputs
157
  cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
158
  gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
 
163
  attention_mask=attention_mask,
164
  **cross_attention_kwargs,
165
  )
166
+
167
  if self.norm_type == "ada_norm_zero":
168
  attn_output = gate_msa.unsqueeze(1) * attn_output
169
  elif self.norm_type == "ada_norm_single":
 
172
  hidden_states = attn_output + hidden_states
173
  if hidden_states.ndim == 4:
174
  hidden_states = hidden_states.squeeze(1)
175
+
176
  # 1.2 Reference Attention
177
  if 'w' in mode:
178
+ condition_embed_dict[self.layer_name] = rearrange(
179
+ norm_hidden_states, '(b n) l c -> b (n l) c',
180
+ n=num_in_batch
181
+ ) # B, (N L), C
182
+
183
+ if 'r' in mode and self.use_ra:
184
+ condition_embed = condition_embed_dict[self.layer_name].unsqueeze(1).repeat(1, num_in_batch, 1,
185
+ 1) # B N L C
186
  condition_embed = rearrange(condition_embed, 'b n l c -> (b n) l c')
187
 
188
  attn_output = self.attn_refview(
 
191
  attention_mask=None,
192
  **cross_attention_kwargs
193
  )
194
+ if not self.is_turbo:
195
+ ref_scale_timing = ref_scale
196
+ if isinstance(ref_scale, torch.Tensor):
197
+ ref_scale_timing = ref_scale.unsqueeze(1).repeat(1, num_in_batch).view(-1)
198
+ for _ in range(attn_output.ndim - 1):
199
+ ref_scale_timing = ref_scale_timing.unsqueeze(-1)
200
+
201
+ hidden_states = ref_scale_timing * attn_output + hidden_states
202
 
 
203
  if hidden_states.ndim == 4:
204
  hidden_states = hidden_states.squeeze(1)
 
205
 
206
  # 1.3 Multiview Attention
207
  if num_in_batch > 1 and self.use_ma:
208
  multivew_hidden_states = rearrange(norm_hidden_states, '(b n) l c -> b (n l) c', n=num_in_batch)
209
 
210
+ if self.is_turbo:
211
+ position_mask = None
212
+ if position_attn_mask is not None:
213
+ if multivew_hidden_states.shape[1] in position_attn_mask:
214
+ position_mask = position_attn_mask[multivew_hidden_states.shape[1]]
215
+ position_indices = None
216
+ if position_voxel_indices is not None:
217
+ if multivew_hidden_states.shape[1] in position_voxel_indices:
218
+ position_indices = position_voxel_indices[multivew_hidden_states.shape[1]]
219
+ attn_output = self.attn_multiview(
220
+ multivew_hidden_states,
221
+ encoder_hidden_states=multivew_hidden_states,
222
+ attention_mask=position_mask,
223
+ position_indices=position_indices,
224
+ **cross_attention_kwargs
225
+ )
226
+ else:
227
+ attn_output = self.attn_multiview(
228
+ multivew_hidden_states,
229
+ encoder_hidden_states=multivew_hidden_states,
230
+ **cross_attention_kwargs
231
+ )
232
 
233
+ attn_output = rearrange(attn_output, 'b (n l) c -> (b n) l c', n=num_in_batch)
234
+
235
+ hidden_states = mva_scale * attn_output + hidden_states
236
  if hidden_states.ndim == 4:
237
  hidden_states = hidden_states.squeeze(1)
238
 
 
258
  if self.pos_embed is not None and self.norm_type != "ada_norm_single":
259
  norm_hidden_states = self.pos_embed(norm_hidden_states)
260
 
261
+ attn_output = self.attn2(
262
+ norm_hidden_states,
263
+ encoder_hidden_states=encoder_hidden_states,
264
+ attention_mask=encoder_attention_mask,
265
+ **cross_attention_kwargs,
266
+ )
267
 
268
  hidden_states = attn_output + hidden_states
269
 
 
310
  position[valid_mask==False] = 0
311
 
312
 
313
+ position = rearrange(
314
+ position,
315
+ 'b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w',
316
+ num_h=grid_resolution, num_w=grid_resolution
317
+ )
318
+ valid_mask = rearrange(
319
+ valid_mask,
320
+ 'b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w',
321
+ num_h=grid_resolution, num_w=grid_resolution
322
+ )
323
 
324
  grid_position = position.sum(dim=(-2, -1))
325
  count_masked = valid_mask.sum(dim=(-2, -1))
 
366
  valid_mask = valid_mask.expand_as(position)
367
  position[valid_mask==False] = 0
368
 
369
+ position = rearrange(
370
+ position,
371
+ 'b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w',
372
+ num_h=grid_resolution, num_w=grid_resolution
373
+ )
374
+ valid_mask = rearrange(
375
+ valid_mask,
376
+ 'b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w',
377
+ num_h=grid_resolution, num_w=grid_resolution
378
+ )
379
 
380
  grid_position = position.sum(dim=(-2, -1))
381
  count_masked = valid_mask.sum(dim=(-2, -1))
 
388
  voxel_indices = torch.round(voxel_indices).long()
389
  return voxel_indices
390
 
391
+ def compute_multi_resolution_discrete_voxel_indice(
392
+ position_maps,
393
+ grid_resolutions=[64, 32, 16, 8],
394
+ voxel_resolutions=[512, 256, 128, 64]
395
+ ):
396
  voxel_indices = {}
397
  with torch.no_grad():
398
  for grid_resolution, voxel_resolution in zip(grid_resolutions, voxel_resolutions):
399
  voxel_indice = compute_discrete_voxel_indice(position_maps, grid_resolution, voxel_resolution)
400
  voxel_indice = rearrange(voxel_indice, 'b n c h w -> b (n h w) c')
401
  voxel_indices[voxel_indice.shape[1]] = {'voxel_indices':voxel_indice, 'voxel_resolution':voxel_resolution}
402
+ return voxel_indices
403
 
404
  class UNet2p5DConditionModel(torch.nn.Module):
405
  def __init__(self, unet: UNet2DConditionModel) -> None:
406
  super().__init__()
407
  self.unet = unet
 
408
 
409
+ self.use_ma = True
410
+ self.use_ra = True
411
+ self.use_camera_embedding = True
412
+ self.use_dual_stream = True
413
+ self.is_turbo = False
414
+
415
+ if self.use_dual_stream:
416
+ self.unet_dual = copy.deepcopy(unet)
417
+ self.init_attention(self.unet_dual)
418
+ self.init_attention(self.unet, use_ma=self.use_ma, use_ra=self.use_ra, is_turbo=self.is_turbo)
419
  self.init_condition()
420
+ self.init_camera_embedding()
421
 
422
  @staticmethod
423
  def from_pretrained(pretrained_model_name_or_path, **kwargs):
 
428
  config = json.load(file)
429
  unet = UNet2DConditionModel(**config)
430
  unet = UNet2p5DConditionModel(unet)
431
  unet_ckpt = torch.load(unet_ckpt_path, map_location='cpu', weights_only=True)
432
  unet.load_state_dict(unet_ckpt, strict=True)
433
  unet = unet.to(torch_dtype)
434
  return unet
 
435
 
436
+ def init_condition(self):
437
+ self.unet.conv_in = torch.nn.Conv2d(
438
+ 12,
439
+ self.unet.conv_in.out_channels,
440
+ kernel_size=self.unet.conv_in.kernel_size,
441
+ stride=self.unet.conv_in.stride,
442
+ padding=self.unet.conv_in.padding,
443
+ dilation=self.unet.conv_in.dilation,
444
+ groups=self.unet.conv_in.groups,
445
+ bias=self.unet.conv_in.bias is not None)
446
 
447
+ self.unet.learned_text_clip_gen = nn.Parameter(torch.randn(1, 77, 1024))
448
+ self.unet.learned_text_clip_ref = nn.Parameter(torch.randn(1, 77, 1024))
449
 
450
  def init_camera_embedding(self):
 
 
451
 
452
+ if self.use_camera_embedding:
453
+ time_embed_dim = 1280
454
+ self.max_num_ref_image = 5
455
+ self.max_num_gen_image = 12 * 3 + 4 * 2
456
+ self.unet.class_embedding = nn.Embedding(self.max_num_ref_image + self.max_num_gen_image, time_embed_dim)
457
+
458
+ def init_attention(self, unet, use_ma=False, use_ra=False, is_turbo=False):
459
 
460
  for down_block_i, down_block in enumerate(unet.down_blocks):
461
  if hasattr(down_block, "has_cross_attention") and down_block.has_cross_attention:
462
  for attn_i, attn in enumerate(down_block.attentions):
463
  for transformer_i, transformer in enumerate(attn.transformer_blocks):
464
  if isinstance(transformer, BasicTransformerBlock):
465
+ attn.transformer_blocks[transformer_i] = Basic2p5DTransformerBlock(
466
+ transformer,
467
+ f'down_{down_block_i}_{attn_i}_{transformer_i}',
468
+ use_ma, use_ra, is_turbo
469
+ )
470
 
471
  if hasattr(unet.mid_block, "has_cross_attention") and unet.mid_block.has_cross_attention:
472
  for attn_i, attn in enumerate(unet.mid_block.attentions):
473
  for transformer_i, transformer in enumerate(attn.transformer_blocks):
474
  if isinstance(transformer, BasicTransformerBlock):
475
+ attn.transformer_blocks[transformer_i] = Basic2p5DTransformerBlock(
476
+ transformer,
477
+ f'mid_{attn_i}_{transformer_i}',
478
+ use_ma, use_ra, is_turbo
479
+ )
480
 
481
  for up_block_i, up_block in enumerate(unet.up_blocks):
482
  if hasattr(up_block, "has_cross_attention") and up_block.has_cross_attention:
483
  for attn_i, attn in enumerate(up_block.attentions):
484
  for transformer_i, transformer in enumerate(attn.transformer_blocks):
485
  if isinstance(transformer, BasicTransformerBlock):
486
+ attn.transformer_blocks[transformer_i] = Basic2p5DTransformerBlock(
487
+ transformer,
488
+ f'up_{up_block_i}_{attn_i}_{transformer_i}',
489
+ use_ma, use_ra, is_turbo
490
+ )
491
 
492
  def __getattr__(self, name: str):
493
  try:
494
  return super().__getattr__(name)
495
  except AttributeError:
496
  return getattr(self.unet, name)
497
+
498
  def forward(
499
+ self, sample, timestep, encoder_hidden_states,
500
+ *args, down_intrablock_additional_residuals=None,
501
  down_block_res_samples=None, mid_block_res_sample=None,
502
  **cached_condition,
503
  ):
504
  B, N_gen, _, H, W = sample.shape
505
+ assert H == W
506
+
507
+ if self.use_camera_embedding:
508
+ camera_info_gen = cached_condition['camera_info_gen'] + self.max_num_ref_image
509
+ camera_info_gen = rearrange(camera_info_gen, 'b n -> (b n)')
510
+ else:
511
+ camera_info_gen = None
512
+
513
  sample = [sample]
 
514
  if 'normal_imgs' in cached_condition:
515
  sample.append(cached_condition["normal_imgs"])
516
  if 'position_imgs' in cached_condition:
517
  sample.append(cached_condition["position_imgs"])
 
518
  sample = torch.cat(sample, dim=2)
519
+
520
  sample = rearrange(sample, 'b n c h w -> (b n) c h w')
521
 
522
  encoder_hidden_states_gen = encoder_hidden_states.unsqueeze(1).repeat(1, N_gen, 1, 1)
523
  encoder_hidden_states_gen = rearrange(encoder_hidden_states_gen, 'b n l c -> (b n) l c')
524
 
525
+ if self.use_ra:
526
+ if 'condition_embed_dict' in cached_condition:
527
+ condition_embed_dict = cached_condition['condition_embed_dict']
 
 
528
  else:
529
+ condition_embed_dict = {}
530
+ ref_latents = cached_condition['ref_latents']
531
+ N_ref = ref_latents.shape[1]
532
+ if self.use_camera_embedding:
533
+ camera_info_ref = cached_condition['camera_info_ref']
534
+ camera_info_ref = rearrange(camera_info_ref, 'b n -> (b n)')
535
+ else:
536
+ camera_info_ref = None
537
+
538
+ ref_latents = rearrange(ref_latents, 'b n c h w -> (b n) c h w')
539
+
540
+ encoder_hidden_states_ref = self.unet.learned_text_clip_ref.unsqueeze(1).repeat(B, N_ref, 1, 1)
541
+ encoder_hidden_states_ref = rearrange(encoder_hidden_states_ref, 'b n l c -> (b n) l c')
542
+
543
+ noisy_ref_latents = ref_latents
544
+ timestep_ref = 0
545
+
546
+ if self.use_dual_stream:
547
+ unet_ref = self.unet_dual
548
+ else:
549
+ unet_ref = self.unet
550
+ unet_ref(
551
+ noisy_ref_latents, timestep_ref,
552
+ encoder_hidden_states=encoder_hidden_states_ref,
553
+ class_labels=camera_info_ref,
554
+ # **kwargs
555
+ return_dict=False,
556
+ cross_attention_kwargs={
557
+ 'mode': 'w', 'num_in_batch': N_ref,
558
+ 'condition_embed_dict': condition_embed_dict},
559
+ )
560
+ cached_condition['condition_embed_dict'] = condition_embed_dict
561
  else:
562
+ condition_embed_dict = None
563
 
564
+ mva_scale = cached_condition.get('mva_scale', 1.0)
565
+ ref_scale = cached_condition.get('ref_scale', 1.0)
566
 
567
+ if self.is_turbo:
568
+ cross_attention_kwargs_ = {
569
+ 'mode': 'r', 'num_in_batch': N_gen,
570
+ 'condition_embed_dict': condition_embed_dict,
571
+ 'position_attn_mask':position_attn_mask,
572
+ 'position_voxel_indices':position_voxel_indices,
573
+ 'mva_scale': mva_scale,
574
+ 'ref_scale': ref_scale,
575
+ }
576
+ else:
577
+ cross_attention_kwargs_ = {
578
+ 'mode': 'r', 'num_in_batch': N_gen,
579
+ 'condition_embed_dict': condition_embed_dict,
580
+ 'mva_scale': mva_scale,
581
+ 'ref_scale': ref_scale,
582
+ }
583
  return self.unet(
584
  sample, timestep,
585
  encoder_hidden_states_gen, *args,
 
595
  if mid_block_res_sample is not None else None
596
  ),
597
  return_dict=False,
598
+ cross_attention_kwargs=cross_attention_kwargs_,
599
+ )
600
+
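Taken together, the modules.py changes remove the dedicated IP-Adapter and pose-RoPE attention processor classes, fold the turbo-specific behaviour behind an is_turbo flag on Basic2p5DTransformerBlock (which now takes use_ma/use_ra/is_turbo instead of use_ipa/use_ma/use_ra), and route two conditioning scales, mva_scale and ref_scale, from cached_condition into the multiview and reference attention branches (both default to 1.0). In the non-turbo path a tensor-valued ref_scale is broadcast over the generated views before the residual add. The snippet below is a small, self-contained sketch of that broadcasting step with dummy shapes; it illustrates the logic added in Basic2p5DTransformerBlock.forward and is not a drop-in replacement for it:

import torch

def scaled_ref_residual(attn_output, hidden_states, ref_scale, num_in_batch):
    # attn_output, hidden_states: [(B * num_in_batch), L, C]
    # ref_scale: a float, or a per-sample tensor of shape [B]
    ref_scale_timing = ref_scale
    if isinstance(ref_scale, torch.Tensor):
        # Repeat the per-sample scale for every generated view, then add
        # trailing singleton dims so it broadcasts over tokens and channels.
        ref_scale_timing = ref_scale.unsqueeze(1).repeat(1, num_in_batch).view(-1)
        for _ in range(attn_output.ndim - 1):
            ref_scale_timing = ref_scale_timing.unsqueeze(-1)
    return ref_scale_timing * attn_output + hidden_states

# Dummy example: 2 prompts, 6 generated views each, 16 tokens, 8 channels.
B, N, L, C = 2, 6, 16, 8
attn_out = torch.randn(B * N, L, C)
hidden = torch.randn(B * N, L, C)
per_sample_scale = torch.tensor([1.0, 0.5])
out = scaled_ref_residual(attn_out, hidden, per_sample_scale, num_in_batch=N)
print(out.shape)  # torch.Size([12, 16, 8])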