Diffusers
Safetensors
x-omni
custom_code
zhangxiaosong18 commited on
Commit
19854ad
·
verified ·
1 Parent(s): 15f6a54

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
added_tokens.json ADDED
The diff for this file is too large to render. See raw diff
 
config.json ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "XOmniForCausalLM"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "auto_map": {
7
+ "AutoConfig": "modeling_xomni.XOmniConfig",
8
+ "AutoModel": "modeling_xomni.XOmniModel",
9
+ "AutoModelForCausalLM": "modeling_xomni.XOmniForCausalLM"
10
+ },
11
+ "eos_token_id": 151643,
12
+ "hidden_act": "silu",
13
+ "hidden_size": 3584,
14
+ "image_vocab_size": 16384,
15
+ "initializer_range": 0.02,
16
+ "intermediate_size": 18944,
17
+ "max_position_embeddings": 8192,
18
+ "max_window_layers": 28,
19
+ "mm_special_tokens": [
20
+ "<SOM>",
21
+ "<EOM>",
22
+ "<IMAGE>"
23
+ ],
24
+ "mm_vocab_size": 16448,
25
+ "model_type": "x-omni",
26
+ "num_attention_heads": 28,
27
+ "num_hidden_layers": 36,
28
+ "num_key_value_heads": 4,
29
+ "num_mm_adap_layers": 4,
30
+ "num_mm_head_layers": 4,
31
+ "pad_token_id": 151643,
32
+ "rms_norm_eps": 1e-06,
33
+ "rope_scaling": null,
34
+ "rope_theta": 1000000.0,
35
+ "sliding_window": null,
36
+ "tie_word_embeddings": false,
37
+ "torch_dtype": "bfloat16",
38
+ "transformers_version": "4.52.0",
39
+ "use_cache": true,
40
+ "use_sliding_window": false,
41
+ "vision_config":{
42
+ "transform": {
43
+ "short_size": 384,
44
+ "long_size": 1152,
45
+ "patch_size": 16,
46
+ "random_ratio": null,
47
+ "min_short_size": 128,
48
+ "max_aspect_ratio": 3.0,
49
+ "filtering": false
50
+ },
51
+ "encoder": {
52
+ "siglip_name": "siglip2_giant_patch16_384",
53
+ "siglip_path": "vit/vit_g.pth",
54
+ "projector_path": "vit/siglip_vq.pt",
55
+ "with_norm": true,
56
+ "z_channels": 1536,
57
+ "codebook_size": 16384,
58
+ "codebook_dim": 2048
59
+ },
60
+ "decoder": {
61
+ "model_path": "diffusers",
62
+ "num_inference_steps": 28,
63
+ "cfg_scale": 1.5,
64
+ "cfg_scale_2": 1.5,
65
+ "upscale_factor": 16
66
+ },
67
+ "dtype": "bfloat16"
68
+ },
69
+ "vocab_size": 151936
70
+ }
configuration_xomni.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoConfig, Qwen2Config
2
+ from typing import Tuple
3
+
4
+
5
+ class XOmniConfig(Qwen2Config):
6
+ model_type = "x-omni"
7
+
8
+ def __init__(
9
+ self,
10
+ num_mm_adap_layers: int = 4,
11
+ num_mm_head_layers: int = 4,
12
+ mm_vocab_size: int = 16448,
13
+ image_vocab_size: int = 16384,
14
+ mm_special_tokens: Tuple[str] = ('<SOM>', '<EOM>', '<IMAGE>'),
15
+ **kwargs,
16
+ ):
17
+ super().__init__(**kwargs)
18
+ self.num_mm_adap_layers = num_mm_adap_layers
19
+ self.num_mm_head_layers = num_mm_head_layers
20
+ self.mm_vocab_size = mm_vocab_size
21
+ self.image_vocab_size = image_vocab_size
22
+ self.mm_special_tokens = mm_special_tokens
23
+
24
+
25
+ AutoConfig.register("x-omni", XOmniConfig)
diffusers/config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_diffusers_version": "0.33.0.dev0",
3
+ "attention_head_dim": 128,
4
+ "axes_dims_rope": [
5
+ 16,
6
+ 56,
7
+ 56
8
+ ],
9
+ "drop_token_prob": 0.0,
10
+ "guidance_embeds": true,
11
+ "hidden_size": null,
12
+ "in_channels": 64,
13
+ "joint_attention_dim": 4096,
14
+ "num_attention_heads": 24,
15
+ "num_layers": 19,
16
+ "num_single_layers": 38,
17
+ "out_channels": null,
18
+ "patch_size": 1,
19
+ "pooled_projection_dim": 768,
20
+ "siglip_channels": null
21
+ }
diffusers/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a9ae32ab1c024a19aaf050f0b2eb38bb8ea069b2298e922cd8d2d8328445387e
3
+ size 23812392896
generation_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.52.0"
4
+ }
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6a3201f2f35a4f3176430b949a6d8c90913eb2bf0b0398ee896fcbcf24bdab84
3
+ size 4994657152
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3918a1e7bbba2b804e6ea9c5298f3b6f419ca1093efbaee3ccf20880abc07c0c
3
+ size 4932751016
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3ba948766a1cd1e285aba59caa3abfabb7a8cce71743366c3b0f4f3434a6655b
3
+ size 4991495904
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8cd5fe02d25527abb4f4f47d6d74b230554a6cf931062496f49d3cd202db2bd1
3
+ size 4275274456
model.safetensors.index.json ADDED
@@ -0,0 +1,445 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 19194128384
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "model-00004-of-00004.safetensors",
7
+ "mm_head.weight": "model-00004-of-00004.safetensors",
8
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
9
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
10
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
11
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
12
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
13
+ "model.layers.0.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
14
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
15
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
16
+ "model.layers.0.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
17
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
18
+ "model.layers.0.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
19
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
20
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
21
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
22
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
23
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
24
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
25
+ "model.layers.1.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
26
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
27
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
28
+ "model.layers.1.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
29
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
30
+ "model.layers.1.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
31
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
32
+ "model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
33
+ "model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
34
+ "model.layers.10.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
35
+ "model.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
36
+ "model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
37
+ "model.layers.10.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
38
+ "model.layers.10.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
39
+ "model.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
40
+ "model.layers.10.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
41
+ "model.layers.10.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
42
+ "model.layers.10.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
43
+ "model.layers.10.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
44
+ "model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
45
+ "model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
46
+ "model.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
47
+ "model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
48
+ "model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
49
+ "model.layers.11.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
50
+ "model.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
51
+ "model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
52
+ "model.layers.11.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
53
+ "model.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
54
+ "model.layers.11.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
55
+ "model.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
56
+ "model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
57
+ "model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
58
+ "model.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
59
+ "model.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
60
+ "model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
61
+ "model.layers.12.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
62
+ "model.layers.12.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
63
+ "model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
64
+ "model.layers.12.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
65
+ "model.layers.12.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
66
+ "model.layers.12.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
67
+ "model.layers.12.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
68
+ "model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
69
+ "model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
70
+ "model.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
71
+ "model.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
72
+ "model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
73
+ "model.layers.13.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
74
+ "model.layers.13.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
75
+ "model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
76
+ "model.layers.13.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
77
+ "model.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
78
+ "model.layers.13.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
79
+ "model.layers.13.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
80
+ "model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
81
+ "model.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
82
+ "model.layers.14.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
83
+ "model.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
84
+ "model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
85
+ "model.layers.14.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
86
+ "model.layers.14.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
87
+ "model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
88
+ "model.layers.14.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
89
+ "model.layers.14.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
90
+ "model.layers.14.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
91
+ "model.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
92
+ "model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
93
+ "model.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
94
+ "model.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
95
+ "model.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
96
+ "model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
97
+ "model.layers.15.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
98
+ "model.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
99
+ "model.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
100
+ "model.layers.15.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
101
+ "model.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
102
+ "model.layers.15.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
103
+ "model.layers.15.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
104
+ "model.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
105
+ "model.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
106
+ "model.layers.16.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
107
+ "model.layers.16.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
108
+ "model.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
109
+ "model.layers.16.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
110
+ "model.layers.16.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
111
+ "model.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
112
+ "model.layers.16.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
113
+ "model.layers.16.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
114
+ "model.layers.16.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
115
+ "model.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
116
+ "model.layers.17.input_layernorm.weight": "model-00002-of-00004.safetensors",
117
+ "model.layers.17.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
118
+ "model.layers.17.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
119
+ "model.layers.17.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
120
+ "model.layers.17.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
121
+ "model.layers.17.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
122
+ "model.layers.17.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
123
+ "model.layers.17.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
124
+ "model.layers.17.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
125
+ "model.layers.17.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
126
+ "model.layers.17.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
127
+ "model.layers.17.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
128
+ "model.layers.18.input_layernorm.weight": "model-00002-of-00004.safetensors",
129
+ "model.layers.18.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
130
+ "model.layers.18.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
131
+ "model.layers.18.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
132
+ "model.layers.18.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
133
+ "model.layers.18.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
134
+ "model.layers.18.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
135
+ "model.layers.18.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
136
+ "model.layers.18.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
137
+ "model.layers.18.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
138
+ "model.layers.18.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
139
+ "model.layers.18.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
140
+ "model.layers.19.input_layernorm.weight": "model-00003-of-00004.safetensors",
141
+ "model.layers.19.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
142
+ "model.layers.19.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
143
+ "model.layers.19.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
144
+ "model.layers.19.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
145
+ "model.layers.19.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
146
+ "model.layers.19.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
147
+ "model.layers.19.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
148
+ "model.layers.19.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
149
+ "model.layers.19.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
150
+ "model.layers.19.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
151
+ "model.layers.19.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
152
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
153
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
154
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
155
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
156
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
157
+ "model.layers.2.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
158
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
159
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
160
+ "model.layers.2.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
161
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
162
+ "model.layers.2.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
163
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
164
+ "model.layers.20.input_layernorm.weight": "model-00003-of-00004.safetensors",
165
+ "model.layers.20.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
166
+ "model.layers.20.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
167
+ "model.layers.20.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
168
+ "model.layers.20.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
169
+ "model.layers.20.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
170
+ "model.layers.20.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
171
+ "model.layers.20.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
172
+ "model.layers.20.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
173
+ "model.layers.20.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
174
+ "model.layers.20.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
175
+ "model.layers.20.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
176
+ "model.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
177
+ "model.layers.21.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
178
+ "model.layers.21.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
179
+ "model.layers.21.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
180
+ "model.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
181
+ "model.layers.21.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
182
+ "model.layers.21.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
183
+ "model.layers.21.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
184
+ "model.layers.21.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
185
+ "model.layers.21.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
186
+ "model.layers.21.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
187
+ "model.layers.21.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
188
+ "model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
189
+ "model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
190
+ "model.layers.22.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
191
+ "model.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
192
+ "model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
193
+ "model.layers.22.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
194
+ "model.layers.22.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
195
+ "model.layers.22.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
196
+ "model.layers.22.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
197
+ "model.layers.22.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
198
+ "model.layers.22.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
199
+ "model.layers.22.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
200
+ "model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
201
+ "model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
202
+ "model.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
203
+ "model.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
204
+ "model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
205
+ "model.layers.23.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
206
+ "model.layers.23.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
207
+ "model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
208
+ "model.layers.23.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
209
+ "model.layers.23.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
210
+ "model.layers.23.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
211
+ "model.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
212
+ "model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
213
+ "model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
214
+ "model.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
215
+ "model.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
216
+ "model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
217
+ "model.layers.24.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
218
+ "model.layers.24.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
219
+ "model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
220
+ "model.layers.24.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
221
+ "model.layers.24.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
222
+ "model.layers.24.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
223
+ "model.layers.24.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
224
+ "model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
225
+ "model.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
226
+ "model.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
227
+ "model.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
228
+ "model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
229
+ "model.layers.25.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
230
+ "model.layers.25.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
231
+ "model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
232
+ "model.layers.25.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
233
+ "model.layers.25.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
234
+ "model.layers.25.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
235
+ "model.layers.25.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
236
+ "model.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
237
+ "model.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
238
+ "model.layers.26.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
239
+ "model.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
240
+ "model.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
241
+ "model.layers.26.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
242
+ "model.layers.26.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
243
+ "model.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
244
+ "model.layers.26.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
245
+ "model.layers.26.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
246
+ "model.layers.26.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
247
+ "model.layers.26.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
248
+ "model.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
249
+ "model.layers.27.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
250
+ "model.layers.27.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
251
+ "model.layers.27.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
252
+ "model.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
253
+ "model.layers.27.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
254
+ "model.layers.27.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
255
+ "model.layers.27.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
256
+ "model.layers.27.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
257
+ "model.layers.27.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
258
+ "model.layers.27.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
259
+ "model.layers.27.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
260
+ "model.layers.28.input_layernorm.weight": "model-00003-of-00004.safetensors",
261
+ "model.layers.28.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
262
+ "model.layers.28.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
263
+ "model.layers.28.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
264
+ "model.layers.28.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
265
+ "model.layers.28.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
266
+ "model.layers.28.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
267
+ "model.layers.28.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
268
+ "model.layers.28.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
269
+ "model.layers.28.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
270
+ "model.layers.28.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
271
+ "model.layers.28.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
272
+ "model.layers.29.input_layernorm.weight": "model-00003-of-00004.safetensors",
273
+ "model.layers.29.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
274
+ "model.layers.29.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
275
+ "model.layers.29.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
276
+ "model.layers.29.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
277
+ "model.layers.29.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
278
+ "model.layers.29.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
279
+ "model.layers.29.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
280
+ "model.layers.29.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
281
+ "model.layers.29.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
282
+ "model.layers.29.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
283
+ "model.layers.29.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
284
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
285
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
286
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
287
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
288
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
289
+ "model.layers.3.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
290
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
291
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
292
+ "model.layers.3.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
293
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
294
+ "model.layers.3.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
295
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
296
+ "model.layers.30.input_layernorm.weight": "model-00004-of-00004.safetensors",
297
+ "model.layers.30.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
298
+ "model.layers.30.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
299
+ "model.layers.30.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
300
+ "model.layers.30.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
301
+ "model.layers.30.self_attn.k_proj.bias": "model-00004-of-00004.safetensors",
302
+ "model.layers.30.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
303
+ "model.layers.30.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
304
+ "model.layers.30.self_attn.q_proj.bias": "model-00004-of-00004.safetensors",
305
+ "model.layers.30.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
306
+ "model.layers.30.self_attn.v_proj.bias": "model-00004-of-00004.safetensors",
307
+ "model.layers.30.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
308
+ "model.layers.31.input_layernorm.weight": "model-00004-of-00004.safetensors",
309
+ "model.layers.31.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
310
+ "model.layers.31.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
311
+ "model.layers.31.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
312
+ "model.layers.31.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
313
+ "model.layers.31.self_attn.k_proj.bias": "model-00004-of-00004.safetensors",
314
+ "model.layers.31.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
315
+ "model.layers.31.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
316
+ "model.layers.31.self_attn.q_proj.bias": "model-00004-of-00004.safetensors",
317
+ "model.layers.31.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
318
+ "model.layers.31.self_attn.v_proj.bias": "model-00004-of-00004.safetensors",
319
+ "model.layers.31.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
320
+ "model.layers.32.input_layernorm.weight": "model-00004-of-00004.safetensors",
321
+ "model.layers.32.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
322
+ "model.layers.32.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
323
+ "model.layers.32.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
324
+ "model.layers.32.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
325
+ "model.layers.32.self_attn.k_proj.bias": "model-00004-of-00004.safetensors",
326
+ "model.layers.32.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
327
+ "model.layers.32.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
328
+ "model.layers.32.self_attn.q_proj.bias": "model-00004-of-00004.safetensors",
329
+ "model.layers.32.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
330
+ "model.layers.32.self_attn.v_proj.bias": "model-00004-of-00004.safetensors",
331
+ "model.layers.32.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
332
+ "model.layers.33.input_layernorm.weight": "model-00004-of-00004.safetensors",
333
+ "model.layers.33.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
334
+ "model.layers.33.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
335
+ "model.layers.33.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
336
+ "model.layers.33.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
337
+ "model.layers.33.self_attn.k_proj.bias": "model-00004-of-00004.safetensors",
338
+ "model.layers.33.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
339
+ "model.layers.33.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
340
+ "model.layers.33.self_attn.q_proj.bias": "model-00004-of-00004.safetensors",
341
+ "model.layers.33.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
342
+ "model.layers.33.self_attn.v_proj.bias": "model-00004-of-00004.safetensors",
343
+ "model.layers.33.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
344
+ "model.layers.34.input_layernorm.weight": "model-00004-of-00004.safetensors",
345
+ "model.layers.34.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
346
+ "model.layers.34.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
347
+ "model.layers.34.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
348
+ "model.layers.34.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
349
+ "model.layers.34.self_attn.k_proj.bias": "model-00004-of-00004.safetensors",
350
+ "model.layers.34.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
351
+ "model.layers.34.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
352
+ "model.layers.34.self_attn.q_proj.bias": "model-00004-of-00004.safetensors",
353
+ "model.layers.34.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
354
+ "model.layers.34.self_attn.v_proj.bias": "model-00004-of-00004.safetensors",
355
+ "model.layers.34.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
356
+ "model.layers.35.input_layernorm.weight": "model-00004-of-00004.safetensors",
357
+ "model.layers.35.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
358
+ "model.layers.35.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
359
+ "model.layers.35.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
360
+ "model.layers.35.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
361
+ "model.layers.35.self_attn.k_proj.bias": "model-00004-of-00004.safetensors",
362
+ "model.layers.35.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
363
+ "model.layers.35.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
364
+ "model.layers.35.self_attn.q_proj.bias": "model-00004-of-00004.safetensors",
365
+ "model.layers.35.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
366
+ "model.layers.35.self_attn.v_proj.bias": "model-00004-of-00004.safetensors",
367
+ "model.layers.35.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
368
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
369
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
370
+ "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
371
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
372
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
373
+ "model.layers.4.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
374
+ "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
375
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
376
+ "model.layers.4.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
377
+ "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
378
+ "model.layers.4.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
379
+ "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
380
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00004.safetensors",
381
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
382
+ "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
383
+ "model.layers.5.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
384
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
385
+ "model.layers.5.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
386
+ "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
387
+ "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
388
+ "model.layers.5.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
389
+ "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
390
+ "model.layers.5.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
391
+ "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
392
+ "model.layers.6.input_layernorm.weight": "model-00001-of-00004.safetensors",
393
+ "model.layers.6.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
394
+ "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
395
+ "model.layers.6.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
396
+ "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
397
+ "model.layers.6.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
398
+ "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
399
+ "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
400
+ "model.layers.6.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
401
+ "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
402
+ "model.layers.6.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
403
+ "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
404
+ "model.layers.7.input_layernorm.weight": "model-00001-of-00004.safetensors",
405
+ "model.layers.7.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
406
+ "model.layers.7.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
407
+ "model.layers.7.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
408
+ "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
409
+ "model.layers.7.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
410
+ "model.layers.7.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
411
+ "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
412
+ "model.layers.7.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
413
+ "model.layers.7.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
414
+ "model.layers.7.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
415
+ "model.layers.7.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
416
+ "model.layers.8.input_layernorm.weight": "model-00001-of-00004.safetensors",
417
+ "model.layers.8.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
418
+ "model.layers.8.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
419
+ "model.layers.8.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
420
+ "model.layers.8.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
421
+ "model.layers.8.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
422
+ "model.layers.8.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
423
+ "model.layers.8.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
424
+ "model.layers.8.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
425
+ "model.layers.8.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
426
+ "model.layers.8.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
427
+ "model.layers.8.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
428
+ "model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
429
+ "model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
430
+ "model.layers.9.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
431
+ "model.layers.9.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
432
+ "model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
433
+ "model.layers.9.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
434
+ "model.layers.9.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
435
+ "model.layers.9.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
436
+ "model.layers.9.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
437
+ "model.layers.9.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
438
+ "model.layers.9.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
439
+ "model.layers.9.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
440
+ "model.lm_embed_tokens.weight": "model-00001-of-00004.safetensors",
441
+ "model.lm_norm.weight": "model-00004-of-00004.safetensors",
442
+ "model.mm_embed_tokens.weight": "model-00001-of-00004.safetensors",
443
+ "model.mm_norm.weight": "model-00004-of-00004.safetensors"
444
+ }
445
+ }
modeling_siglip_flux.py ADDED
@@ -0,0 +1,841 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+ from typing import Any, Callable, Dict, Tuple, List, Optional, Union
5
+ from diffusers import FluxTransformer2DModel
6
+ from diffusers.configuration_utils import register_to_config
7
+ from diffusers.utils import logging, USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers
8
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
9
+ from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, calculate_shift, retrieve_timesteps
10
+ from diffusers.image_processor import PipelineImageInput
11
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
12
+
13
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
14
+
15
+
16
+ def drop_token(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
17
+ if drop_prob == 0. or not training:
18
+ return x
19
+ keep_prob = 1 - drop_prob
20
+ shape = (x.shape[0], x.shape[1], 1)
21
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
22
+ if keep_prob > 0.0 and scale_by_keep:
23
+ random_tensor.div_(keep_prob)
24
+ return x * random_tensor
25
+
26
+
27
+ class FluxTransformer2DModelWithSigLIP(FluxTransformer2DModel):
28
+ @register_to_config
29
+ def __init__(
30
+ self,
31
+ patch_size: int = 1,
32
+ in_channels: int = 64,
33
+ out_channels: Optional[int] = None,
34
+ num_layers: int = 19,
35
+ num_single_layers: int = 38,
36
+ attention_head_dim: int = 128,
37
+ num_attention_heads: int = 24,
38
+ joint_attention_dim: int = 4096,
39
+ pooled_projection_dim: int = 768,
40
+ guidance_embeds: bool = False,
41
+ axes_dims_rope: Tuple[int] = (16, 56, 56),
42
+ siglip_channels: Optional[int] = None,
43
+ drop_token_prob: float = 0.,
44
+ ):
45
+ super().__init__(
46
+ patch_size=patch_size,
47
+ in_channels=in_channels,
48
+ out_channels=out_channels,
49
+ num_layers=num_layers,
50
+ num_single_layers=num_single_layers,
51
+ attention_head_dim=attention_head_dim,
52
+ num_attention_heads=num_attention_heads,
53
+ joint_attention_dim=joint_attention_dim,
54
+ pooled_projection_dim=pooled_projection_dim,
55
+ guidance_embeds=guidance_embeds,
56
+ axes_dims_rope=axes_dims_rope,
57
+ )
58
+ self.drop_token_prob = drop_token_prob
59
+ if siglip_channels is not None:
60
+ self.init_siglip_embed(siglip_channels)
61
+
62
+ def init_siglip_embed(self, siglip_channels):
63
+ self.siglip_embed = torch.nn.Linear(siglip_channels, self.inner_dim, bias=False)
64
+ torch.nn.init.zeros_(self.siglip_embed.weight)
65
+
66
+ def forward(
67
+ self,
68
+ hidden_states: torch.Tensor,
69
+ encoder_hidden_states: torch.Tensor = None,
70
+ pooled_projections: torch.Tensor = None,
71
+ timestep: torch.LongTensor = None,
72
+ img_ids: torch.Tensor = None,
73
+ txt_ids: torch.Tensor = None,
74
+ guidance: torch.Tensor = None,
75
+ siglip_tensor: Optional[torch.Tensor] = None,
76
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
77
+ controlnet_block_samples=None,
78
+ controlnet_single_block_samples=None,
79
+ return_dict: bool = True,
80
+ controlnet_blocks_repeat: bool = False,
81
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
82
+ """
83
+ The [`FluxTransformer2DModel`] forward method.
84
+
85
+ Args:
86
+ hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
87
+ Input `hidden_states`.
88
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
89
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
90
+ pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
91
+ from the embeddings of input conditions.
92
+ timestep ( `torch.LongTensor`):
93
+ Used to indicate denoising step.
94
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
95
+ A list of tensors that if specified are added to the residuals of transformer blocks.
96
+ joint_attention_kwargs (`dict`, *optional*):
97
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
98
+ `self.processor` in
99
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
100
+ return_dict (`bool`, *optional*, defaults to `True`):
101
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
102
+ tuple.
103
+
104
+ Returns:
105
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
106
+ `tuple` where the first element is the sample tensor.
107
+ """
108
+ if joint_attention_kwargs is not None:
109
+ joint_attention_kwargs = joint_attention_kwargs.copy()
110
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
111
+ else:
112
+ lora_scale = 1.0
113
+
114
+ if USE_PEFT_BACKEND:
115
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
116
+ scale_lora_layers(self, lora_scale)
117
+ else:
118
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
119
+ logger.warning(
120
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
121
+ )
122
+
123
+ hidden_states = self.x_embedder(hidden_states)
124
+
125
+ timestep = timestep.to(hidden_states.dtype) * 1000
126
+ if guidance is not None:
127
+ guidance = guidance.to(hidden_states.dtype) * 1000
128
+ else:
129
+ guidance = None
130
+
131
+ temb = (
132
+ self.time_text_embed(timestep, pooled_projections)
133
+ if guidance is None
134
+ else self.time_text_embed(timestep, guidance, pooled_projections)
135
+ )
136
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
137
+
138
+ if txt_ids.ndim == 3:
139
+ logger.warning(
140
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
141
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
142
+ )
143
+ txt_ids = txt_ids[0]
144
+ if img_ids.ndim == 3:
145
+ logger.warning(
146
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
147
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
148
+ )
149
+ img_ids = img_ids[0]
150
+
151
+ ids = torch.cat((txt_ids, img_ids), dim=0)
152
+ image_rotary_emb = self.pos_embed(ids)
153
+
154
+ if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
155
+ ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
156
+ ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
157
+ joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
158
+
159
+ for index_block, block in enumerate(self.transformer_blocks):
160
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
161
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
162
+ block,
163
+ hidden_states,
164
+ encoder_hidden_states,
165
+ temb,
166
+ image_rotary_emb,
167
+ )
168
+
169
+ else:
170
+ encoder_hidden_states, hidden_states = block(
171
+ hidden_states=hidden_states,
172
+ encoder_hidden_states=encoder_hidden_states,
173
+ temb=temb,
174
+ image_rotary_emb=image_rotary_emb,
175
+ joint_attention_kwargs=joint_attention_kwargs,
176
+ )
177
+
178
+ # controlnet residual
179
+ if controlnet_block_samples is not None:
180
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
181
+ interval_control = int(np.ceil(interval_control))
182
+ # For Xlabs ControlNet.
183
+ if controlnet_blocks_repeat:
184
+ hidden_states = (
185
+ hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
186
+ )
187
+ else:
188
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
189
+
190
+ if siglip_tensor is not None:
191
+ siglip_tensor = drop_token(siglip_tensor, self.drop_token_prob, training=self.training)
192
+ hidden_states = hidden_states + self.siglip_embed(siglip_tensor)
193
+
194
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
195
+
196
+ for index_block, block in enumerate(self.single_transformer_blocks):
197
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
198
+ hidden_states = self._gradient_checkpointing_func(
199
+ block,
200
+ hidden_states,
201
+ temb,
202
+ image_rotary_emb,
203
+ )
204
+
205
+ else:
206
+ hidden_states = block(
207
+ hidden_states=hidden_states,
208
+ temb=temb,
209
+ image_rotary_emb=image_rotary_emb,
210
+ joint_attention_kwargs=joint_attention_kwargs,
211
+ )
212
+
213
+ # controlnet residual
214
+ if controlnet_single_block_samples is not None:
215
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
216
+ interval_control = int(np.ceil(interval_control))
217
+ hidden_states[:, encoder_hidden_states.shape[1]:, ...] = (
218
+ hidden_states[:, encoder_hidden_states.shape[1]:, ...]
219
+ + controlnet_single_block_samples[index_block // interval_control]
220
+ )
221
+
222
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1]:, ...]
223
+
224
+ hidden_states = self.norm_out(hidden_states, temb)
225
+ output = self.proj_out(hidden_states)
226
+
227
+ if USE_PEFT_BACKEND:
228
+ # remove `lora_scale` from each PEFT layer
229
+ unscale_lora_layers(self, lora_scale)
230
+
231
+ if not return_dict:
232
+ return (output,)
233
+
234
+ return Transformer2DModelOutput(sample=output)
235
+
236
+
237
+ def teacache_forward(
238
+ self,
239
+ hidden_states: torch.Tensor,
240
+ encoder_hidden_states: torch.Tensor = None,
241
+ pooled_projections: torch.Tensor = None,
242
+ timestep: torch.LongTensor = None,
243
+ img_ids: torch.Tensor = None,
244
+ txt_ids: torch.Tensor = None,
245
+ guidance: torch.Tensor = None,
246
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
247
+ controlnet_block_samples=None,
248
+ controlnet_single_block_samples=None,
249
+ return_dict: bool = True,
250
+ controlnet_blocks_repeat: bool = False,
251
+ siglip_tensor: Optional[torch.Tensor] = None,
252
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
253
+ """
254
+ The [`FluxTransformer2DModel`] forward method.
255
+
256
+ Args:
257
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
258
+ Input `hidden_states`.
259
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
260
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
261
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
262
+ from the embeddings of input conditions.
263
+ timestep ( `torch.LongTensor`):
264
+ Used to indicate denoising step.
265
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
266
+ A list of tensors that if specified are added to the residuals of transformer blocks.
267
+ joint_attention_kwargs (`dict`, *optional*):
268
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
269
+ `self.processor` in
270
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
271
+ return_dict (`bool`, *optional*, defaults to `True`):
272
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
273
+ tuple.
274
+
275
+ Returns:
276
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
277
+ `tuple` where the first element is the sample tensor.
278
+ """
279
+ if joint_attention_kwargs is not None:
280
+ joint_attention_kwargs = joint_attention_kwargs.copy()
281
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
282
+ else:
283
+ lora_scale = 1.0
284
+
285
+ if USE_PEFT_BACKEND:
286
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
287
+ scale_lora_layers(self, lora_scale)
288
+ else:
289
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
290
+ logger.warning(
291
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
292
+ )
293
+
294
+ batch_size, seq_len, channels = hidden_states.shape
295
+ device, dtype = hidden_states.device, hidden_states.dtype
296
+ hidden_states = self.x_embedder(hidden_states)
297
+
298
+ timestep = timestep.to(hidden_states.dtype) * 1000
299
+ if guidance is not None:
300
+ guidance = guidance.to(hidden_states.dtype) * 1000
301
+ else:
302
+ guidance = None
303
+
304
+ temb = (
305
+ self.time_text_embed(timestep, pooled_projections)
306
+ if guidance is None
307
+ else self.time_text_embed(timestep, guidance, pooled_projections)
308
+ )
309
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
310
+
311
+ if txt_ids.ndim == 3:
312
+ logger.warning(
313
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
314
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
315
+ )
316
+ txt_ids = txt_ids[0]
317
+ if img_ids.ndim == 3:
318
+ logger.warning(
319
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
320
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
321
+ )
322
+ img_ids = img_ids[0]
323
+
324
+ ids = torch.cat((txt_ids, img_ids), dim=0)
325
+ image_rotary_emb = self.pos_embed(ids)
326
+
327
+ if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
328
+ ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
329
+ ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
330
+ joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
331
+
332
+ if self.enable_teacache:
333
+ inp = hidden_states.clone()
334
+ temb_ = temb.clone()
335
+ modulated_inp, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.transformer_blocks[0].norm1(inp, emb=temb_)
336
+ if self.cnt == 0 or self.cnt == self.num_steps - 1:
337
+ should_calc = True
338
+ self.accumulated_rel_l1_distance = 0
339
+ else:
340
+ coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01]
341
+ rescale_func = np.poly1d(coefficients)
342
+ # rescale_func = Polynomial(coefficients.reverse())
343
+ self.accumulated_rel_l1_distance += rescale_func(((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
344
+ if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
345
+ should_calc = False
346
+ else:
347
+ should_calc = True
348
+ self.accumulated_rel_l1_distance = 0
349
+ self.previous_modulated_input = modulated_inp
350
+ self.cnt += 1
351
+ if self.cnt == self.num_steps:
352
+ self.cnt = 0
353
+
354
+ if self.enable_teacache:
355
+ if not should_calc:
356
+ hidden_states += self.previous_residual
357
+ else:
358
+ ori_hidden_states = hidden_states.clone()
359
+ for index_block, block in enumerate(self.transformer_blocks):
360
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
361
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
362
+ block,
363
+ hidden_states,
364
+ encoder_hidden_states,
365
+ temb,
366
+ image_rotary_emb,
367
+ )
368
+
369
+ else:
370
+ encoder_hidden_states, hidden_states = block(
371
+ hidden_states=hidden_states,
372
+ encoder_hidden_states=encoder_hidden_states,
373
+ temb=temb,
374
+ image_rotary_emb=image_rotary_emb,
375
+ joint_attention_kwargs=joint_attention_kwargs,
376
+ )
377
+
378
+ # controlnet residual
379
+ if controlnet_block_samples is not None:
380
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
381
+ interval_control = int(np.ceil(interval_control))
382
+ # For Xlabs ControlNet.
383
+ if controlnet_blocks_repeat:
384
+ hidden_states = (
385
+ hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
386
+ )
387
+ else:
388
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
389
+
390
+ if siglip_tensor is not None:
391
+ siglip_tensor = drop_token(siglip_tensor, self.drop_token_prob, training=self.training)
392
+ hidden_states = hidden_states + self.siglip_embed(siglip_tensor)
393
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
394
+
395
+ for index_block, block in enumerate(self.single_transformer_blocks):
396
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
397
+ hidden_states = self._gradient_checkpointing_func(
398
+ block,
399
+ hidden_states,
400
+ temb,
401
+ image_rotary_emb,
402
+ )
403
+
404
+ else:
405
+ hidden_states = block(
406
+ hidden_states=hidden_states,
407
+ temb=temb,
408
+ image_rotary_emb=image_rotary_emb,
409
+ joint_attention_kwargs=joint_attention_kwargs,
410
+ )
411
+
412
+ # controlnet residual
413
+ if controlnet_single_block_samples is not None:
414
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
415
+ interval_control = int(np.ceil(interval_control))
416
+ hidden_states[:, encoder_hidden_states.shape[1]:, ...] = (
417
+ hidden_states[:, encoder_hidden_states.shape[1]:, ...]
418
+ + controlnet_single_block_samples[index_block // interval_control]
419
+ )
420
+
421
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1]:, ...]
422
+ self.previous_residual = hidden_states - ori_hidden_states
423
+ else:
424
+ for index_block, block in enumerate(self.transformer_blocks):
425
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
426
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
427
+ block,
428
+ hidden_states,
429
+ encoder_hidden_states,
430
+ temb,
431
+ image_rotary_emb,
432
+ )
433
+ else:
434
+ encoder_hidden_states, hidden_states = block(
435
+ hidden_states=hidden_states,
436
+ encoder_hidden_states=encoder_hidden_states,
437
+ temb=temb,
438
+ image_rotary_emb=image_rotary_emb,
439
+ joint_attention_kwargs=joint_attention_kwargs,
440
+ )
441
+
442
+ # controlnet residual
443
+ if controlnet_block_samples is not None:
444
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
445
+ interval_control = int(np.ceil(interval_control))
446
+ # For Xlabs ControlNet.
447
+ if controlnet_blocks_repeat:
448
+ hidden_states = (
449
+ hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
450
+ )
451
+ else:
452
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
453
+ if siglip_tensor is not None:
454
+ siglip_tensor = drop_token(siglip_tensor, self.drop_token_prob, training=self.training)
455
+ hidden_states = hidden_states + self.siglip_embed(siglip_tensor)
456
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
457
+
458
+ for index_block, block in enumerate(self.single_transformer_blocks):
459
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
460
+ hidden_states = self._gradient_checkpointing_func(
461
+ block,
462
+ hidden_states,
463
+ temb,
464
+ image_rotary_emb,
465
+ )
466
+
467
+ else:
468
+ hidden_states = block(
469
+ hidden_states=hidden_states,
470
+ temb=temb,
471
+ image_rotary_emb=image_rotary_emb,
472
+ joint_attention_kwargs=joint_attention_kwargs,
473
+ )
474
+
475
+ # controlnet residual
476
+ if controlnet_single_block_samples is not None:
477
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
478
+ interval_control = int(np.ceil(interval_control))
479
+ hidden_states[:, encoder_hidden_states.shape[1]:, ...] = (
480
+ hidden_states[:, encoder_hidden_states.shape[1]:, ...]
481
+ + controlnet_single_block_samples[index_block // interval_control]
482
+ )
483
+
484
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1]:, ...]
485
+
486
+ hidden_states = self.norm_out(hidden_states, temb)
487
+ output = self.proj_out(hidden_states)
488
+
489
+ if USE_PEFT_BACKEND:
490
+ # remove `lora_scale` from each PEFT layer
491
+ unscale_lora_layers(self, lora_scale)
492
+
493
+ if not return_dict:
494
+ return (output,)
495
+
496
+ return Transformer2DModelOutput(sample=output)
497
+
498
+
499
+ class FluxPipelineWithSigLIP(FluxPipeline):
500
+
501
+ @torch.no_grad()
502
+ def __call__(
503
+ self,
504
+ siglip_tensor: torch.Tensor,
505
+ prompt: Union[str, List[str]] = None,
506
+ prompt_2: Optional[Union[str, List[str]]] = None,
507
+ negative_prompt: Union[str, List[str]] = None,
508
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
509
+ true_cfg_scale: float = 1.0,
510
+ true_cfg_scale_2: float = 1.0,
511
+ height: Optional[int] = None,
512
+ width: Optional[int] = None,
513
+ num_inference_steps: int = 28,
514
+ sigmas: Optional[List[float]] = None,
515
+ guidance_scale: float = 3.5,
516
+ num_images_per_prompt: Optional[int] = 1,
517
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
518
+ latents: Optional[torch.FloatTensor] = None,
519
+ prompt_embeds: Optional[torch.FloatTensor] = None,
520
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
521
+ ip_adapter_image: Optional[PipelineImageInput] = None,
522
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
523
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
524
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
525
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
526
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
527
+ output_type: Optional[str] = "pil",
528
+ return_dict: bool = True,
529
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
530
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
531
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
532
+ max_sequence_length: int = 512,
533
+ ):
534
+ r"""
535
+ Function invoked when calling the pipeline for generation.
536
+
537
+ Args:
538
+ prompt (`str` or `List[str]`, *optional*):
539
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
540
+ instead.
541
+ prompt_2 (`str` or `List[str]`, *optional*):
542
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
543
+ will be used instead.
544
+ negative_prompt (`str` or `List[str]`, *optional*):
545
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
546
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
547
+ not greater than `1`).
548
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
549
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
550
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
551
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
552
+ When > 1.0 and a `negative_prompt` is provided, true classifier-free guidance is enabled.
553
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
554
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
555
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
556
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
557
+ num_inference_steps (`int`, *optional*, defaults to 28):
558
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
559
+ expense of slower inference.
560
+ sigmas (`List[float]`, *optional*):
561
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
562
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
563
+ will be used.
564
+ guidance_scale (`float`, *optional*, defaults to 3.5):
565
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
566
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
567
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
568
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
569
+ usually at the expense of lower image quality.
570
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
571
+ The number of images to generate per prompt.
572
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
573
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
574
+ to make generation deterministic.
575
+ latents (`torch.FloatTensor`, *optional*):
576
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
577
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
578
+ tensor will be generated by sampling using the supplied random `generator`.
579
+ prompt_embeds (`torch.FloatTensor`, *optional*):
580
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
581
+ provided, text embeddings will be generated from `prompt` input argument.
582
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
583
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
584
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
585
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
586
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
587
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
588
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
589
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
590
+ negative_ip_adapter_image:
591
+ (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
592
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
593
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
594
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
595
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
596
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
597
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
598
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
599
+ argument.
600
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
601
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
602
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
603
+ input argument.
604
+ output_type (`str`, *optional*, defaults to `"pil"`):
605
+ The output format of the generated image. Choose between
606
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
607
+ return_dict (`bool`, *optional*, defaults to `True`):
608
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
609
+ joint_attention_kwargs (`dict`, *optional*):
610
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
611
+ `self.processor` in
612
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
613
+ callback_on_step_end (`Callable`, *optional*):
614
+ A function that is called at the end of each denoising step during inference. The function is called
615
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
616
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
617
+ `callback_on_step_end_tensor_inputs`.
618
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
619
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
620
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
621
+ `._callback_tensor_inputs` attribute of your pipeline class.
622
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
623
+
624
+ Examples:
625
+
626
+ Returns:
627
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
628
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
629
+ images.
630
+ """
631
+ assert true_cfg_scale == true_cfg_scale_2
632
+
633
+ height = height or self.default_sample_size * self.vae_scale_factor
634
+ width = width or self.default_sample_size * self.vae_scale_factor
635
+
636
+ # 1. Check inputs. Raise error if not correct
637
+ self.check_inputs(
638
+ prompt,
639
+ prompt_2,
640
+ height,
641
+ width,
642
+ negative_prompt=negative_prompt,
643
+ negative_prompt_2=negative_prompt_2,
644
+ prompt_embeds=prompt_embeds,
645
+ negative_prompt_embeds=negative_prompt_embeds,
646
+ pooled_prompt_embeds=pooled_prompt_embeds,
647
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
648
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
649
+ max_sequence_length=max_sequence_length,
650
+ )
651
+
652
+ self._guidance_scale = guidance_scale
653
+ self._joint_attention_kwargs = joint_attention_kwargs
654
+ self._current_timestep = None
655
+ self._interrupt = False
656
+
657
+ # 2. Define call parameters
658
+ if prompt is not None and isinstance(prompt, str):
659
+ batch_size = 1
660
+ elif prompt is not None and isinstance(prompt, list):
661
+ batch_size = len(prompt)
662
+ else:
663
+ batch_size = prompt_embeds.shape[0]
664
+
665
+ device = self._execution_device
666
+
667
+ lora_scale = (
668
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
669
+ )
670
+ has_neg_prompt = negative_prompt is not None or (
671
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
672
+ )
673
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
674
+ (
675
+ prompt_embeds,
676
+ pooled_prompt_embeds,
677
+ text_ids,
678
+ ) = self.encode_prompt(
679
+ prompt=prompt,
680
+ prompt_2=prompt_2,
681
+ prompt_embeds=prompt_embeds,
682
+ pooled_prompt_embeds=pooled_prompt_embeds,
683
+ device=device,
684
+ num_images_per_prompt=num_images_per_prompt,
685
+ max_sequence_length=max_sequence_length,
686
+ lora_scale=lora_scale,
687
+ )
688
+ assert do_true_cfg
689
+ (
690
+ negative_prompt_embeds,
691
+ negative_pooled_prompt_embeds,
692
+ _,
693
+ ) = self.encode_prompt(
694
+ prompt=negative_prompt,
695
+ prompt_2=negative_prompt_2,
696
+ prompt_embeds=negative_prompt_embeds,
697
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
698
+ device=device,
699
+ num_images_per_prompt=num_images_per_prompt,
700
+ max_sequence_length=max_sequence_length,
701
+ lora_scale=lora_scale,
702
+ )
703
+
704
+ # 4. Prepare latent variables
705
+ num_channels_latents = self.transformer.config.in_channels // 4
706
+ latents, latent_image_ids = self.prepare_latents(
707
+ batch_size * num_images_per_prompt,
708
+ num_channels_latents,
709
+ height,
710
+ width,
711
+ prompt_embeds.dtype,
712
+ device,
713
+ generator,
714
+ latents,
715
+ )
716
+
717
+ # 5. Prepare timesteps
718
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
719
+ image_seq_len = latents.shape[1]
720
+ mu = calculate_shift(
721
+ image_seq_len,
722
+ self.scheduler.config.get("base_image_seq_len", 256),
723
+ self.scheduler.config.get("max_image_seq_len", 4096),
724
+ self.scheduler.config.get("base_shift", 0.5),
725
+ self.scheduler.config.get("max_shift", 1.15),
726
+ )
727
+ timesteps, num_inference_steps = retrieve_timesteps(
728
+ self.scheduler,
729
+ num_inference_steps,
730
+ device,
731
+ sigmas=sigmas,
732
+ mu=mu,
733
+ )
734
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
735
+ self._num_timesteps = len(timesteps)
736
+
737
+ # handle guidance
738
+ if self.transformer.config.guidance_embeds:
739
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
740
+ guidance = guidance.expand(latents.shape[0] * 2)
741
+ else:
742
+ guidance = None
743
+
744
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
745
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
746
+ ):
747
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
748
+ negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
749
+
750
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
751
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
752
+ ):
753
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
754
+ ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
755
+
756
+ if self.joint_attention_kwargs is None:
757
+ self._joint_attention_kwargs = {}
758
+
759
+ image_embeds = None
760
+ negative_image_embeds = None
761
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
762
+ image_embeds = self.prepare_ip_adapter_image_embeds(
763
+ ip_adapter_image,
764
+ ip_adapter_image_embeds,
765
+ device,
766
+ batch_size * num_images_per_prompt,
767
+ )
768
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
769
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
770
+ negative_ip_adapter_image,
771
+ negative_ip_adapter_image_embeds,
772
+ device,
773
+ batch_size * num_images_per_prompt,
774
+ )
775
+
776
+ # 6. Denoising loop
777
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
778
+ for i, t in enumerate(timesteps):
779
+ if self.interrupt:
780
+ continue
781
+
782
+ self._current_timestep = t
783
+ if image_embeds is not None:
784
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
785
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
786
+ timestep = t.expand(latents.shape[0] * 2).to(latents.dtype)
787
+
788
+ batch_noise_pred = self.transformer(
789
+ hidden_states=torch.cat([latents, latents], dim=0),
790
+ timestep=timestep / 1000,
791
+ guidance=guidance,
792
+ pooled_projections=torch.cat([pooled_prompt_embeds, negative_pooled_prompt_embeds.expand_as(pooled_prompt_embeds)], dim=0),
793
+ encoder_hidden_states=torch.cat([prompt_embeds, negative_prompt_embeds.expand_as(prompt_embeds)], dim=0),
794
+ txt_ids=text_ids,
795
+ img_ids=latent_image_ids,
796
+ joint_attention_kwargs=self.joint_attention_kwargs,
797
+ siglip_tensor=torch.cat([siglip_tensor, torch.zeros_like(siglip_tensor)], dim=0),
798
+ return_dict=False,
799
+ )[0]
800
+ noise_pred, neg_noise_pred = batch_noise_pred.chunk(2)
801
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
802
+
803
+ # compute the previous noisy sample x_t -> x_t-1
804
+ latents_dtype = latents.dtype
805
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
806
+
807
+ if latents.dtype != latents_dtype:
808
+ if torch.backends.mps.is_available():
809
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
810
+ latents = latents.to(latents_dtype)
811
+
812
+ if callback_on_step_end is not None:
813
+ callback_kwargs = {}
814
+ for k in callback_on_step_end_tensor_inputs:
815
+ callback_kwargs[k] = locals()[k]
816
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
817
+
818
+ latents = callback_outputs.pop("latents", latents)
819
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
820
+
821
+ # call the callback, if provided
822
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
823
+ progress_bar.update()
824
+
825
+ self._current_timestep = None
826
+
827
+ if output_type == "latent":
828
+ image = latents
829
+ else:
830
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
831
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
832
+ image = self.vae.decode(latents, return_dict=False)[0]
833
+ image = self.image_processor.postprocess(image, output_type=output_type)
834
+
835
+ # Offload all models
836
+ self.maybe_free_model_hooks()
837
+
838
+ if not return_dict:
839
+ return (image,)
840
+
841
+ return FluxPipelineOutput(images=image)
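
The forward above injects the SigLIP embedding into the packed image latents and runs true classifier-free guidance as a single batched transformer call (positive and negative branches are concatenated along the batch dimension); the TeaCache path in teacache_forward is gated by attributes such as enable_teacache, rel_l1_thresh, num_steps and cnt, which are not initialized here and appear to be set externally. Below is a minimal, hypothetical usage sketch of FluxPipelineWithSigLIP; the module name, checkpoint location, output resolution and the 1536-channel width of siglip_tensor are assumptions, not part of this commit.

import torch
from pipeline_flux_with_siglip import FluxPipelineWithSigLIP  # hypothetical module name for the file above

# Assumes the `diffusers/` folder of this repo holds a complete Flux-style checkpoint
# whose transformer uses the patched forward defined above.
pipe = FluxPipelineWithSigLIP.from_pretrained("diffusers", torch_dtype=torch.bfloat16).to("cuda")

height, width = 768, 768                  # output resolution, multiples of 16
seq_len = (height // 16) * (width // 16)  # packed latent sequence length seen by the transformer
# Stand-in for the decoded SigLIP codebook features (the channel width is an assumption).
siglip_tensor = torch.randn(1, seq_len, 1536, device="cuda", dtype=torch.bfloat16)

image = pipe(
    siglip_tensor=siglip_tensor,
    prompt="a photo of a cat",
    negative_prompt="",       # required: __call__ asserts that true CFG is active
    true_cfg_scale=1.5,
    true_cfg_scale_2=1.5,     # must equal true_cfg_scale
    height=height,
    width=width,
    num_inference_steps=28,
).images[0]
image.save("sample.png")
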
modeling_siglip_tokenizer.py ADDED
@@ -0,0 +1,231 @@
1
+
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch import einsum
7
+ from torchvision import transforms
8
+
9
+ from PIL import Image
10
+ from einops import rearrange
11
+
12
+ from .modeling_vit import create_siglip_vit
13
+
14
+
15
+ def create_anyres_preprocess(
16
+ short_size=384,
17
+ long_size=1152,
18
+ patch_size=16,
19
+ random_ratio=None,
20
+ min_short_size=128,
21
+ max_aspect_ratio=3.,
22
+ filtering=True
23
+ ):
24
+
25
+ def resize_and_filtering(pil_image):
26
+ pil_image = pil_image.convert('RGB')
27
+ width, height = pil_image.size
28
+ ss, ls = min(width, height), max(width, height)
29
+ aspect_ratio = ls / ss
30
+ if filtering and (ss < min_short_size or aspect_ratio > max_aspect_ratio):
31
+ return None
32
+ target_width, target_height = width, height
33
+ if random_ratio is not None:
34
+ log_ratio = torch.log(torch.tensor(random_ratio))
35
+ sqrt_ratio = torch.exp(0.5 * torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()
36
+ target_width = int(round(target_width * sqrt_ratio))
37
+ target_height = int(round(target_height / sqrt_ratio))
38
+
39
+ ss = min(target_width, target_height)
40
+ if ss < short_size:
41
+ target_width = target_width * (short_size / ss)
42
+ target_height = target_height * (short_size / ss)
43
+
44
+ ls = max(target_width, target_height)
45
+ if ls > long_size:
46
+ target_width = target_width * (long_size / ls)
47
+ target_height = target_height * (long_size / ls)
48
+
49
+ target_width = int(round(target_width / patch_size)) * patch_size
50
+ target_height = int(round(target_height / patch_size)) * patch_size
51
+ pil_image = pil_image.resize((target_width, target_height), resample=Image.BICUBIC)
52
+
53
+ to_tensor = transforms.Compose([
54
+ transforms.ToTensor(),
55
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
56
+ ])
57
+ return to_tensor(pil_image)
58
+
59
+ transform = transforms.Lambda(resize_and_filtering)
60
+ return transform
61
+
62
+
63
+ class IBQ(nn.Module):
64
+ def __init__(self, n_e, e_dim, skip_quantization_prob=0.0, quantization_temp=2.0, beta=0.25, sane_index_shape=False, l2_norm=True):
65
+ super().__init__()
66
+ self.n_e = n_e
67
+ self.e_dim = e_dim
68
+ self.quantization_temp = quantization_temp
69
+ self.skip_quantization_prob = skip_quantization_prob
70
+ self.beta = beta
71
+ self.sane_index_shape = sane_index_shape
72
+ self.l2_norm = l2_norm
73
+
74
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
75
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
76
+ if self.l2_norm:
77
+ self.embedding.weight.data = F.normalize(self.embedding.weight.data, p=2, dim=-1)
78
+
79
+ def forward(self, z, temp=None, rescale_logits=False, return_logits=False, **kwargs):
80
+ assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
81
+ assert rescale_logits == False, "Only for interface compatible with Gumbel"
82
+ assert return_logits == False, "Only for interface compatible with Gumbel"
83
+ # reshape z -> (batch, height, width, channel) and flatten
84
+ z = rearrange(z, 'b c h w -> b h w c').contiguous()
85
+ assert z.shape[-1] == self.e_dim
86
+ z_flattened = z.view(-1, self.e_dim)
87
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
88
+
89
+ if self.l2_norm:
90
+ z = F.normalize(z, p=2, dim=-1)
91
+ z_flattened = F.normalize(z_flattened, p=2, dim=-1)
92
+ embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
93
+ else:
94
+ embedding = self.embedding.weight
95
+
96
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
97
+ torch.sum(embedding**2, dim=1) - 2 * \
98
+ torch.einsum('bd,dn->bn', z_flattened, torch.einsum('n d -> d n', embedding))
99
+
100
+ if self.training:
101
+ logits = -d / self.quantization_temp
102
+ soft_one_hot = F.softmax(logits, dim=1)
103
+ min_encoding_indices = soft_one_hot.max(1, keepdim=True)[1]
104
+ hard_one_hot = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(1, min_encoding_indices, 1.0)
105
+ one_hot = hard_one_hot - soft_one_hot.detach() + soft_one_hot
106
+
107
+ z_q = einsum('b n, n d -> b d', one_hot, self.embedding.weight).view(z.shape)
108
+ z_q_2 = einsum('b n, n d -> b d', hard_one_hot, self.embedding.weight).view(z.shape)
109
+
110
+ # compute loss for embedding
111
+ commit_loss = torch.mean((z_q - z) ** 2) + torch.mean((z_q_2.detach() - z) ** 2) + self.beta * \
112
+ torch.mean((z_q_2 - z.detach()) ** 2)
113
+ else:
114
+ min_encoding_indices = torch.argmin(d, dim=1)
115
+ z_q = embedding[min_encoding_indices].view(z.shape)
116
+ commit_loss = None
117
+
118
+ if self.training and self.skip_quantization_prob > 0.0:
119
+ z_q = torch.where(
120
+ torch.rand_like(z_q[:, 0:1, 0:1, 0:1]).expand_as(z_q) <= self.skip_quantization_prob,
121
+ z, z_q,
122
+ )
123
+
124
+ # reshape back to match original input shape
125
+ z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
126
+
127
+ if self.sane_index_shape:
128
+ min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
129
+
130
+ return (z_q, None, min_encoding_indices), commit_loss
131
+
132
+ def get_codebook_entry(self, indices, bhwc):
133
+ # shape specifying (batch, height, width, channel)
134
+ # get quantized latent vectors
135
+ z_q = self.embedding(indices)
136
+
137
+ if bhwc is not None:
138
+ z_q = z_q.view(bhwc)
139
+ # reshape back to match original input shape
140
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
141
+
142
+ return z_q
143
+
144
+
145
+ class ResidualBlock(nn.Module):
146
+ def __init__(self, channels, num_groups=32):
147
+ super().__init__()
148
+ self.conv1 = nn.Conv2d(channels, channels, 3, padding='same')
149
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=channels)
150
+ self.activate = nn.GELU()
151
+ self.conv2 = nn.Conv2d(channels, channels, 3, padding='same')
152
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=channels)
153
+
154
+ def forward(self, x):
155
+ res = x
156
+ x = self.norm1(x)
157
+ x = self.activate(x)
158
+ x = self.conv1(x)
159
+ x = self.norm2(x)
160
+ x = self.activate(x)
161
+ x = self.conv2(x)
162
+ return x + res
163
+
164
+
165
+ class VQConvProjector(nn.Module):
166
+ def __init__(
167
+ self,
168
+ z_channels=1536,
169
+ codebook_size=16384,
170
+ codebook_dim=2048,
171
+ conv_layers=2,
172
+ with_norm=True,
173
+ skip_quant_prob=0.1,
174
+ ):
175
+ super().__init__()
176
+ self.quant_conv = nn.Conv2d(z_channels, codebook_dim, 1)
177
+ self.quantize = IBQ(codebook_size, codebook_dim, skip_quant_prob, sane_index_shape=True)
178
+ self.post_quant_conv = nn.Conv2d(codebook_dim, z_channels, 1)
179
+ block = ResidualBlock
180
+ self.post_conv = nn.Sequential(*[block(z_channels) for _ in range(conv_layers)])
181
+
182
+ def forward(self, x, h, w):
183
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
184
+ z = self.quant_conv(x)
185
+ (z_q, _, _), codebook_loss = self.quantize(z)
186
+ z = self.post_quant_conv(z_q)
187
+ z = self.post_conv(z)
188
+ z = rearrange(z, 'b c h w -> b (h w) c')
189
+ return z, codebook_loss
190
+
191
+ def encode(self, x, h, w):
192
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
193
+ z = self.quant_conv(x)
194
+ (_, _, tokens), _ = self.quantize(z)
195
+ return tokens
196
+
197
+ def decode(self, tokens, bhwc):
198
+ z_q = self.quantize.get_codebook_entry(tokens, bhwc)
199
+ z = self.post_quant_conv(z_q)
200
+ z = self.post_conv(z)
201
+ return z
202
+
203
+
204
+ class SiglipTokenizer(nn.Module):
205
+ def __init__(
206
+ self,
207
+ siglip_name,
208
+ siglip_path,
209
+ projector_path,
210
+ z_channels=1536,
211
+ codebook_size=16384,
212
+ codebook_dim=2048,
213
+ with_norm=True
214
+ ):
215
+ super().__init__()
216
+ self.vit = create_siglip_vit(model_name=siglip_name, path=siglip_path)
217
+ self.vqproj = VQConvProjector(
218
+ z_channels=z_channels,
219
+ codebook_size=codebook_size,
220
+ codebook_dim=codebook_dim,
221
+ with_norm=with_norm
222
+ )
223
+ self.vqproj.load_state_dict(torch.load(projector_path, map_location='cpu'), strict=True)
224
+
225
+ def encode(self, x):
226
+ features, (h, w), _ = self.vit(x)
227
+ tokens = self.vqproj.encode(features, h, w)
228
+ return tokens
229
+
230
+ def decode(self, tokens, bhwc):
231
+ return self.vqproj.decode(tokens, bhwc)
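
A minimal sketch of how this tokenizer might be used to turn an image into discrete codes and map the codes back to continuous features. The checkpoint names mirror the encoder section of config.json; the image file, device and dtype are assumptions (the ViT uses flash-attention, so a CUDA device with bfloat16 is assumed).

import torch
from PIL import Image
from modeling_siglip_tokenizer import SiglipTokenizer, create_anyres_preprocess

preprocess = create_anyres_preprocess(short_size=384, long_size=1152, patch_size=16, filtering=False)
tokenizer = SiglipTokenizer(
    siglip_name="siglip2_giant_patch16_384",
    siglip_path="vit/vit_g.pth",        # paths taken from the encoder section of config.json
    projector_path="vit/siglip_vq.pt",
).eval().to("cuda", torch.bfloat16)

pixels = preprocess(Image.open("example.jpg"))           # (3, H, W), sides snapped to multiples of 16
pixels = pixels.unsqueeze(0).to("cuda", torch.bfloat16)
with torch.no_grad():
    tokens = tokenizer.encode(pixels)                    # (1, H/16, W/16) indices into the 16384-entry codebook
    b, h, w = tokens.shape
    feats = tokenizer.decode(tokens, bhwc=(b, h, w, 2048))  # (1, 1536, h, w) continuous features
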
modeling_vit.py ADDED
@@ -0,0 +1,699 @@
1
+ import math
2
+ import warnings
3
+ from dataclasses import dataclass
4
+ from functools import partial
5
+ from typing import (
6
+ Callable, Dict, Final, List, Literal, Optional,
7
+ Sequence, Set, Tuple, Type, Union,
8
+ )
9
+
10
+ from torch.utils.checkpoint import checkpoint
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ from timm.layers import (
16
+ DropPath, LayerType, Mlp, PatchDropout,
17
+ PatchEmbed, resample_abs_pos_embed,
18
+ )
19
+ from timm.models._manipulate import checkpoint_seq, named_apply
20
+
21
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
22
+
23
+
24
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
25
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
26
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
27
+ def norm_cdf(x):
28
+ # Computes standard normal cumulative distribution function
29
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
30
+
31
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
32
+ warnings.warn(
33
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
34
+ "The distribution of values may be incorrect.",
35
+ stacklevel=2,
36
+ )
37
+
38
+ with torch.no_grad():
39
+ # Values are generated by using a truncated uniform distribution and
40
+ # then using the inverse CDF for the normal distribution.
41
+ # Get upper and lower cdf values
42
+ l = norm_cdf((a - mean) / std) # noqa: E741
43
+ u = norm_cdf((b - mean) / std)
44
+
45
+ # Uniformly fill tensor with values from [l, u], then translate to
46
+ # [2l-1, 2u-1].
47
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
48
+
49
+ # Use inverse cdf transform for normal distribution to get truncated
50
+ # standard normal
51
+ tensor.erfinv_()
52
+
53
+ # Transform to proper mean, std
54
+ tensor.mul_(std * math.sqrt(2.0))
55
+ tensor.add_(mean)
56
+
57
+ # Clamp to ensure it's in the proper range
58
+ tensor.clamp_(min=a, max=b)
59
+ return tensor
60
+
61
+
62
+ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
63
+ # type: (torch.Tensor, float, float, float, float) -> torch.Tensor
64
+ r"""The original timm.models.layers.weight_init.trunc_normal_ cannot handle bfloat16 yet, so here we first
65
+ convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its original dtype.
66
+ Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn
67
+ from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
68
+ with values outside :math:`[a, b]` redrawn until they are within
69
+ the bounds. The method used for generating the random values works
70
+ best when :math:`a \leq \text{mean} \leq b`.
71
+ Args:
72
+ tensor: an n-dimensional `torch.Tensor`
73
+ mean: the mean of the normal distribution
74
+ std: the standard deviation of the normal distribution
75
+ a: the minimum cutoff value
76
+ b: the maximum cutoff value
77
+ Examples:
78
+ >>> w = torch.empty(3, 5)
79
+ >>> nn.init.trunc_normal_(w)
80
+ """
81
+
82
+ with torch.no_grad():
83
+ dtype = tensor.dtype
84
+ tensor_fp32 = tensor.float()
85
+ tensor_fp32 = _no_grad_trunc_normal_(tensor_fp32, mean, std, a, b)
86
+ tensor_dtype = tensor_fp32.to(dtype=dtype)
87
+ tensor.copy_(tensor_dtype)
88
+
89
+
90
+ def init_weights(self):
91
+ if self.pos_embed is not None:
92
+ trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
93
+ trunc_normal_(self.latent, std=self.latent_dim**-0.5)
94
+
95
+
96
+ def init_weights_vit_timm(module: nn.Module, name: str = "") -> None:
97
+ """ViT weight initialization, original timm impl (for reproducibility)"""
98
+ if isinstance(module, nn.Linear):
99
+ trunc_normal_(module.weight, std=0.02)
100
+ if module.bias is not None:
101
+ nn.init.zeros_(module.bias)
102
+ elif hasattr(module, "init_weights"):
103
+ module.init_weights()
104
+
105
+
106
+ class Attention(nn.Module):
107
+ fused_attn: Final[bool]
108
+
109
+ def __init__(
110
+ self,
111
+ dim: int,
112
+ num_heads: int = 8,
113
+ qkv_bias: bool = False,
114
+ qk_norm: bool = False,
115
+ attn_drop: float = 0.0,
116
+ proj_drop: float = 0.0,
117
+ norm_layer: nn.Module = nn.LayerNorm,
118
+ ) -> None:
119
+ super().__init__()
120
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
121
+ self.num_heads = num_heads
122
+ self.head_dim = dim // num_heads
123
+ self.scale = self.head_dim**-0.5
124
+ # self.fused_attn = use_fused_attn()
125
+ self.fused_attn = True
126
+
127
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
128
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
129
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
130
+ self.attn_drop = nn.Dropout(attn_drop)
131
+ self.proj = nn.Linear(dim, dim)
132
+ self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity()
133
+
134
+ def forward(self, x: torch.Tensor, cu_slens=None) -> torch.Tensor:
135
+ B, N, C = x.shape
136
+ qkv = (
137
+ self.qkv(x)
138
+ .reshape(B, N, 3, self.num_heads, self.head_dim)
139
+ .permute(2, 0, 3, 1, 4)
140
+ )
141
+ q, k, v = qkv.unbind(0)
142
+ q, k = self.q_norm(q), self.k_norm(k)
143
+
144
+ if cu_slens is not None:
145
+ q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
146
+ k = k.permute(0, 2, 1, 3)
147
+ v = v.permute(0, 2, 1, 3)
148
+ max_seqlen = torch.max(cu_slens[1:] - cu_slens[:-1]).item()
149
+ x = flash_attn_varlen_func(
150
+ q.squeeze(0),
151
+ k.squeeze(0),
152
+ v.squeeze(0),
153
+ cu_seqlens_q=cu_slens,
154
+ cu_seqlens_k=cu_slens,
155
+ max_seqlen_q=max_seqlen,
156
+ max_seqlen_k=max_seqlen,
157
+ softmax_scale=self.scale,
158
+ causal=False,
159
+ )
160
+
161
+ x = x.reshape(B, N, -1)
162
+ x = self.proj(x)
163
+ x = self.proj_drop(x)
164
+
165
+ else:
166
+ q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
167
+ k = k.permute(0, 2, 1, 3)
168
+ v = v.permute(0, 2, 1, 3)
169
+ x = flash_attn_func(q, k, v, softmax_scale=self.scale) # -> b, n, h, c
170
+
171
+ x = x.reshape(B, N, -1)
172
+ x = self.proj(x)
173
+ x = self.proj_drop(x)
174
+ return x
175
+
176
+
177
+ class LayerScale(nn.Module):
178
+ def __init__(
179
+ self,
180
+ dim: int,
181
+ init_values: float = 1e-5,
182
+ inplace: bool = False,
183
+ ) -> None:
184
+ super().__init__()
185
+ self.inplace = inplace
186
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
187
+
188
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
189
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
190
+
191
+
192
+ class Block(nn.Module):
193
+ def __init__(
194
+ self,
195
+ dim: int,
196
+ num_heads: int,
197
+ mlp_ratio: float = 4.0,
198
+ qkv_bias: bool = False,
199
+ qk_norm: bool = False,
200
+ proj_drop: float = 0.0,
201
+ attn_drop: float = 0.0,
202
+ init_values: Optional[float] = None,
203
+ drop_path: float = 0.0,
204
+ act_layer: nn.Module = nn.GELU,
205
+ norm_layer: nn.Module = nn.LayerNorm,
206
+ mlp_layer: nn.Module = Mlp,
207
+ ) -> None:
208
+ super().__init__()
209
+ self.norm1 = norm_layer(dim)
210
+ self.attn = Attention(
211
+ dim,
212
+ num_heads=num_heads,
213
+ qkv_bias=qkv_bias,
214
+ qk_norm=qk_norm,
215
+ attn_drop=attn_drop,
216
+ proj_drop=proj_drop,
217
+ norm_layer=norm_layer,
218
+ )
219
+ self.ls1 = (
220
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
221
+ )
222
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
223
+
224
+ self.norm2 = norm_layer(dim)
225
+ self.mlp = mlp_layer(
226
+ in_features=dim,
227
+ hidden_features=int(dim * mlp_ratio),
228
+ act_layer=act_layer,
229
+ drop=proj_drop,
230
+ )
231
+ self.ls2 = (
232
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
233
+ )
234
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
235
+
236
+ def forward(self, x: torch.Tensor, cu_slens=None) -> torch.Tensor:
237
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), cu_slens=cu_slens)))
238
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
239
+ return x
240
+
241
+
242
+ class VisionTransformer(nn.Module):
243
+ """Vision Transformer
244
+
245
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
246
+ - https://arxiv.org/abs/2010.11929
247
+ """
248
+
249
+ dynamic_img_size: Final[bool]
250
+
251
+ def __init__(
252
+ self,
253
+ img_size: Union[int, Tuple[int, int]] = 224,
254
+ patch_size: Union[int, Tuple[int, int]] = 16,
255
+ in_chans: int = 3,
256
+ num_classes: int = 1000,
257
+ global_pool: Literal["", "avg", "token", "map"] = "token",
258
+ embed_dim: int = 768,
259
+ depth: int = 12,
260
+ num_heads: int = 12,
261
+ mlp_ratio: float = 4.0,
262
+ qkv_bias: bool = True,
263
+ qk_norm: bool = False,
264
+ init_values: Optional[float] = None,
265
+ class_token: bool = True,
266
+ no_embed_class: bool = False,
267
+ reg_tokens: int = 0,
268
+ pre_norm: bool = False,
269
+ fc_norm: Optional[bool] = None,
270
+ dynamic_img_size: bool = False,
271
+ dynamic_img_pad: bool = False,
272
+ drop_rate: float = 0.0,
273
+ pos_drop_rate: float = 0.0,
274
+ patch_drop_rate: float = 0.0,
275
+ proj_drop_rate: float = 0.0,
276
+ attn_drop_rate: float = 0.0,
277
+ drop_path_rate: float = 0.0,
278
+ weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "",
279
+ embed_layer: Callable = PatchEmbed,
280
+ norm_layer: Optional[LayerType] = None,
281
+ act_layer: Optional[LayerType] = None,
282
+ strict_img_size: bool = False,
283
+ block_fn: Type[nn.Module] = Block,
284
+ mlp_layer: Type[nn.Module] = Mlp,
285
+ ignore_head: bool = False,
286
+ ) -> None:
287
+ """
288
+ Args:
289
+ img_size: Input image size.
290
+ patch_size: Patch size.
291
+ in_chans: Number of image input channels.
292
+ num_classes: Number of classes for classification head.
293
+ global_pool: Type of global pooling for final sequence (default: 'token').
294
+ embed_dim: Transformer embedding dimension.
295
+ depth: Depth of transformer.
296
+ num_heads: Number of attention heads.
297
+ mlp_ratio: Ratio of mlp hidden dim to embedding dim.
298
+ qkv_bias: Enable bias for qkv projections if True.
299
+ init_values: Layer-scale init values (layer-scale enabled if not None).
300
+ class_token: Use class token.
301
+ no_embed_class: Don't include position embeddings for class (or reg) tokens.
302
+ reg_tokens: Number of register tokens.
303
+ fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
304
+ drop_rate: Head dropout rate.
305
+ pos_drop_rate: Position embedding dropout rate.
306
+ attn_drop_rate: Attention dropout rate.
307
+ drop_path_rate: Stochastic depth rate.
308
+ weight_init: Weight initialization scheme.
309
+ embed_layer: Patch embedding layer.
310
+ norm_layer: Normalization layer.
311
+ act_layer: MLP activation layer.
312
+ block_fn: Transformer block layer.
313
+ """
314
+ super().__init__()
315
+ assert global_pool in ("", "avg", "token", "map")
316
+ assert class_token or global_pool != "token"
317
+ use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
318
+ # norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
319
+ # act_layer = get_act_layer(act_layer) or nn.GELU
320
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
321
+ act_layer = nn.GELU
322
+
323
+ self.num_classes = num_classes
324
+ self.global_pool = global_pool
325
+ self.num_features = self.embed_dim = (
326
+ embed_dim # num_features for consistency with other models
327
+ )
328
+ self.num_prefix_tokens = 1 if class_token else 0
329
+ self.num_prefix_tokens += reg_tokens
330
+ self.num_reg_tokens = reg_tokens
331
+ self.has_class_token = class_token
332
+ self.no_embed_class = (
333
+ no_embed_class # don't embed prefix positions (includes reg)
334
+ )
335
+ self.dynamic_img_size = dynamic_img_size
336
+ self.grad_checkpointing = False
337
+ self.ignore_head = ignore_head
338
+
339
+ embed_args = {}
340
+ if dynamic_img_size:
341
+ # flatten deferred until after pos embed
342
+ embed_args.update(dict(strict_img_size=False, output_fmt="NHWC"))
343
+ self.patch_embed = embed_layer(
344
+ img_size=img_size,
345
+ patch_size=patch_size,
346
+ in_chans=in_chans,
347
+ embed_dim=embed_dim,
348
+ bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
349
+ dynamic_img_pad=dynamic_img_pad,
350
+ strict_img_size=strict_img_size,
351
+ **embed_args,
352
+ )
353
+ num_patches = self.patch_embed.num_patches
354
+
355
+ self.cls_token = (
356
+ nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
357
+ )
358
+ self.reg_token = (
359
+ nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
360
+ )
361
+ embed_len = (
362
+ num_patches if no_embed_class else num_patches + self.num_prefix_tokens
363
+ )
364
+ self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
365
+ self.pos_drop = nn.Dropout(p=pos_drop_rate)
366
+ if patch_drop_rate > 0:
367
+ self.patch_drop = PatchDropout(
368
+ patch_drop_rate,
369
+ num_prefix_tokens=self.num_prefix_tokens,
370
+ )
371
+ else:
372
+ self.patch_drop = nn.Identity()
373
+ self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
374
+
375
+ dpr = [
376
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
377
+ ] # stochastic depth decay rule
378
+ self.blocks = nn.Sequential(
379
+ *[
380
+ block_fn(
381
+ dim=embed_dim,
382
+ num_heads=num_heads,
383
+ mlp_ratio=mlp_ratio,
384
+ qkv_bias=qkv_bias,
385
+ qk_norm=qk_norm,
386
+ init_values=init_values,
387
+ proj_drop=proj_drop_rate,
388
+ attn_drop=attn_drop_rate,
389
+ drop_path=dpr[i],
390
+ norm_layer=norm_layer,
391
+ act_layer=act_layer,
392
+ mlp_layer=mlp_layer,
393
+ )
394
+ for i in range(depth)
395
+ ]
396
+ )
397
+
398
+ def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None:
399
+ assert mode in ("jax", "jax_nlhb", "moco", "")
400
+ # head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
401
+ trunc_normal_(self.pos_embed, std=0.02)
402
+ if self.cls_token is not None:
403
+ nn.init.normal_(self.cls_token, std=1e-6)
404
+ named_apply(init_weights_vit_timm, self)
405
+
406
+ @torch.jit.ignore
407
+ def no_weight_decay(self) -> Set:
408
+ return {"pos_embed", "cls_token", "dist_token"}
409
+
410
+ @torch.jit.ignore
411
+ def group_matcher(self, coarse: bool = False) -> Dict:
412
+ return dict(
413
+ stem=r"^cls_token|pos_embed|patch_embed", # stem and embed
414
+ blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))],
415
+ )
416
+
417
+ @torch.jit.ignore
418
+ def set_grad_checkpointing(self, enable: bool = True) -> None:
419
+ self.grad_checkpointing = enable
420
+
421
+ @torch.jit.ignore
422
+ def get_classifier(self) -> nn.Module:
423
+ return self.head
424
+
425
+ def reset_classifier(self, num_classes: int, global_pool=None) -> None:
426
+ self.num_classes = num_classes
427
+ if global_pool is not None:
428
+ assert global_pool in ("", "avg", "token", "map")
429
+ if global_pool == "map" and self.attn_pool is None:
430
+ assert (
431
+ False
432
+ ), "Cannot currently add attention pooling in reset_classifier()."
433
+ elif global_pool != "map" and self.attn_pool is not None:
434
+ self.attn_pool = None # remove attention pooling
435
+ self.global_pool = global_pool
436
+ self.head = (
437
+ nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
438
+ )
439
+
440
+ def rescale_positional_embedding(self, out_size):
441
+ h, w = out_size
442
+ pos_embed_shape = int((self.pos_embed.shape[1]) ** 0.5)
443
+ if (h, w) == (pos_embed_shape, pos_embed_shape):
444
+ return self.pos_embed
445
+ rescaled_positional_embedding = \
446
+ self.pos_embed.new_zeros(1, h*w, self.pos_embed.shape[2])
447
+ pe_2d = self.pos_embed[0].T.contiguous().view(1, -1, pos_embed_shape, pos_embed_shape)
448
+ pe_2d = F.interpolate(pe_2d, out_size, mode='bilinear', align_corners=False).view(-1, h*w)
449
+ rescaled_positional_embedding[0] = pe_2d.T.contiguous()
450
+ return rescaled_positional_embedding
451
+
452
+ def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
453
+ if self.dynamic_img_size:
454
+ B, H, W, C = x.shape
455
+ pos_embed = resample_abs_pos_embed(
456
+ self.pos_embed,
457
+ (H, W),
458
+ num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
459
+ )
460
+ x = x.view(B, -1, C)
461
+ else:
462
+ pos_embed = self.pos_embed
463
+
464
+ to_cat = []
465
+ if self.cls_token is not None:
466
+ to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
467
+ if self.reg_token is not None:
468
+ to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
469
+
470
+ if self.no_embed_class:
471
+ # deit-3, updated JAX (big vision)
472
+ # position embedding does not overlap with class token, add then concat
473
+ x = x + pos_embed
474
+ if to_cat:
475
+ x = torch.cat(to_cat + [x], dim=1)
476
+ else:
477
+ # original timm, JAX, and deit vit impl
478
+ # pos_embed has entry for class token, concat then add
479
+ if to_cat:
480
+ x = torch.cat(to_cat + [x], dim=1)
481
+ x = x + pos_embed
482
+
483
+ return self.pos_drop(x)
484
+
485
+ def _intermediate_layers(
486
+ self,
487
+ x: torch.Tensor,
488
+ n: Union[int, Sequence] = 1,
489
+ ) -> List[torch.Tensor]:
490
+ outputs, num_blocks = [], len(self.blocks)
491
+ take_indices = set(
492
+ range(num_blocks - n, num_blocks) if isinstance(n, int) else n
493
+ )
494
+
495
+ # forward pass
496
+ x = self.patch_embed(x)
497
+ x = self._pos_embed(x)
498
+ x = self.patch_drop(x)
499
+ x = self.norm_pre(x)
500
+ for i, blk in enumerate(self.blocks):
501
+ x = blk(x)
502
+ if i in take_indices:
503
+ outputs.append(x)
504
+
505
+ return outputs
506
+
507
+ def get_intermediate_layers(
508
+ self,
509
+ x: torch.Tensor,
510
+ n: Union[int, Sequence] = 1,
511
+ reshape: bool = False,
512
+ return_prefix_tokens: bool = False,
513
+ norm: bool = False,
514
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
515
+ """Intermediate layer accessor (NOTE: This is a WIP experiment).
516
+ Inspired by DINO / DINOv2 interface
517
+ """
518
+ # take last n blocks if n is an int; if n is a sequence, select by matching indices
519
+ outputs = self._intermediate_layers(x, n)
520
+ if norm:
521
+ outputs = [self.norm(out) for out in outputs]
522
+ prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs]
523
+ outputs = [out[:, self.num_prefix_tokens :] for out in outputs]
524
+
525
+ if reshape:
526
+ grid_size = self.patch_embed.grid_size
527
+ outputs = [
528
+ out.reshape(x.shape[0], grid_size[0], grid_size[1], -1)
529
+ .permute(0, 3, 1, 2)
530
+ .contiguous()
531
+ for out in outputs
532
+ ]
533
+
534
+ if return_prefix_tokens:
535
+ return tuple(zip(outputs, prefix_tokens))
536
+ return tuple(outputs)
537
+
538
+ def forward_features_list(self, x_list):
539
+ x_all = []
540
+ image_sizes = []
541
+ for x in x_list:
542
+ bs, _, h, w = x.shape
543
+
544
+ # pad height and width up to a multiple of the patch size
545
+ pad_h = (self.patch_embed.patch_size[0] - h % self.patch_embed.patch_size[0]) % self.patch_embed.patch_size[0]
546
+ pad_w = (self.patch_embed.patch_size[1] - w % self.patch_embed.patch_size[1]) % self.patch_embed.patch_size[1]
547
+ x = F.pad(x, (0, pad_w, 0, pad_h))
548
+
549
+ bs, _, h, w = x.shape
550
+
551
+ h = h // self.patch_embed.patch_size[0]
552
+ w = w // self.patch_embed.patch_size[1]
553
+
554
+ x = self.patch_embed(x)
555
+ x = x + self.rescale_positional_embedding(out_size=(h, w))
556
+ x = self.patch_drop(x)
557
+ x = self.norm_pre(x)
558
+ x_all.append(x)
559
+ image_sizes.append((h, w))
560
+
561
+ slen = [xi.size(1) for xi in x_all]
562
+ x = torch.cat(x_all, dim=1)
563
+
564
+ cu_indices = [0, ]
565
+ for i in slen:
566
+ cu_indices.append(cu_indices[-1] + i)
567
+
568
+ cu_slens = torch.tensor(cu_indices, dtype=torch.int32).to(x.device)
569
+ for idx, blk in enumerate(self.blocks):
570
+ if self.grad_checkpointing and not torch.jit.is_scripting():
571
+ x = checkpoint(blk, x, cu_slens, use_reentrant=True)
572
+ else:
573
+ x = blk(x, cu_slens=cu_slens)
574
+ feats = x.split(slen, dim=1) #[(1, slen, c)]
575
+ return feats, image_sizes
576
+
577
+ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
578
+ bs, _, h, w = x.shape
579
+ h = h // self.patch_embed.patch_size[0]
580
+ w = w // self.patch_embed.patch_size[1]
581
+
582
+ x = self.patch_embed(x)
583
+ # x = self._pos_embed(x)
584
+ x = x + self.rescale_positional_embedding(out_size=(h, w))
585
+ x = self.patch_drop(x)
586
+ x = self.norm_pre(x)
587
+ if self.grad_checkpointing and not torch.jit.is_scripting():
588
+ x = checkpoint_seq(self.blocks, x)
589
+ else:
590
+ x = self.blocks(x)
591
+ return x, (h, w)
592
+
593
+ def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
594
+ x = self.norm(x)
595
+ if self.attn_pool is not None:
596
+ x = self.attn_pool(x)
597
+ elif self.global_pool == "avg":
598
+ x = x[:, self.num_prefix_tokens :].mean(dim=1)
599
+ elif self.global_pool:
600
+ x = x[:, 0] # class token
601
+ x = self.fc_norm(x)
602
+ x = self.head_drop(x)
603
+ return x if pre_logits else self.head(x)
604
+
605
+ def forward(self, x, cal_attn_pool=False):
606
+ if type(x) is list:
607
+ x, image_sizes = self.forward_features_list(x)
608
+ return x, image_sizes, None
609
+ else:
610
+ x, image_sizes = self.forward_features(x)
611
+ return x, image_sizes, None
612
+
613
+ @dataclass
614
+ class SigLIPVisionCfg:
615
+ width: int = 1152
616
+ layers: Union[Tuple[int, int, int, int], int] = 27
617
+ heads: int = 16
618
+ patch_size: int = 14
619
+ image_size: Union[Tuple[int, int], int] = 336
620
+ global_pool: str = "map"
621
+ mlp_ratio: float = 3.7362
622
+ class_token: bool = False
623
+ num_classes: int = 0
624
+ use_checkpoint: bool = False
625
+
626
+
627
+ SigLIP_MODEL_CONFIG = {
628
+ "siglip_so400m_patch16_384": {
629
+ "image_size": 384,
630
+ "patch_size": 16,
631
+ "width": 1152,
632
+ "layers": 27,
633
+ "heads": 16,
634
+ "mlp_ratio": 3.7362,
635
+ "global_pool": "map",
636
+ "use_checkpoint": False,
637
+ },
638
+ "siglip2_giant_patch16_384":{
639
+ "image_size": 384,
640
+ "patch_size": 16,
641
+ "width": 1536,
642
+ "layers": 40,
643
+ "heads": 16,
644
+ "mlp_ratio": 4,
645
+ "global_pool": "map",
646
+ "use_checkpoint": False,
647
+ },
648
+ }
649
+
650
+
651
+ def resize_evaclip_pos_embed(model: VisionTransformer, interpolation: str = 'bicubic'):
652
+ # interpolate position embedding
653
+ orig_size = 24
654
+ new_size = 128
655
+ pos_tokens = model.pos_embed
656
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, model.embed_dim).permute(0, 3, 1, 2)
657
+ pos_tokens = torch.nn.functional.interpolate(
658
+ pos_tokens, size=(new_size, new_size), mode=interpolation, align_corners=False)
659
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
660
+ model.pos_embed = nn.Parameter(pos_tokens, requires_grad=True)
661
+ return model
662
+
663
+
664
+ def create_siglip_vit(
665
+ model_name: str = "siglip_so400m_patch16_384",  # must be a key of SigLIP_MODEL_CONFIG
666
+ select_layer: int = -1,
667
+ path: str = "",
668
+ gradient_checkpointing: bool = False,
669
+ **kwargs,
670
+ ):
671
+ vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name])
672
+
673
+ if select_layer <= 0:
674
+ layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1)
675
+ else:
676
+ layers = min(vision_cfg.layers, select_layer)
677
+
678
+ model = VisionTransformer(
679
+ img_size=2048,
680
+ patch_size=16,
681
+ embed_dim=vision_cfg.width,
682
+ depth=layers,
683
+ num_heads=vision_cfg.heads,
684
+ mlp_ratio=vision_cfg.mlp_ratio,
685
+ class_token=vision_cfg.class_token,
686
+ global_pool=vision_cfg.global_pool,
687
+ dynamic_img_pad=False,
688
+ strict_img_size=False,
689
+ ignore_head=kwargs.get("ignore_head", False),
690
+ weight_init=kwargs.get("weight_init", "skip"),
691
+ num_classes=0
692
+ )
693
+ model.config = vision_cfg
694
+ state_dict = torch.load(path, map_location="cpu")
695
+ model.load_state_dict(state_dict, strict=False)
696
+
697
+ if gradient_checkpointing:
698
+ model.set_grad_checkpointing(True)
699
+ return model
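
A hypothetical usage sketch for the encoder factory above, not part of this commit: it builds the SigLIP-2 giant backbone the way the vision_config.encoder block in config.json selects it. The local import path, the relative checkpoint path, and the CUDA device are assumptions about where and how the repo files were downloaded.

# Hypothetical sketch, assuming the repo files are importable locally and the
# checkpoint path "vit/vit_g.pth" is resolvable from the working directory.
import torch
from modeling_vit import create_siglip_vit  # module added in this commit

vit = create_siglip_vit(
    model_name="siglip2_giant_patch16_384",  # key from SigLIP_MODEL_CONFIG above
    path="vit/vit_g.pth",                    # SigLIP-2 giant checkpoint shipped in this repo
)
vit = vit.to("cuda", torch.bfloat16).eval()

# forward() accepts a single batched tensor or a list of variable-resolution tensors;
# for a 384x576 input at patch size 16 the token grid is 24x36
with torch.no_grad():
    feats, (h, w), _ = vit(torch.randn(1, 3, 384, 576, device="cuda", dtype=torch.bfloat16))
print(feats.shape, (h, w))  # expected: (1, 24 * 36, 1536) patch features and the (24, 36) grid
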
modeling_xomni.py ADDED
@@ -0,0 +1,315 @@
1
+ import os
2
+ from types import SimpleNamespace
3
+ from typing import Tuple, List, Optional, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from huggingface_hub import hf_hub_download
9
+ from transformers import Qwen2ForCausalLM, AutoModel, AutoModelForCausalLM
10
+ from transformers.modeling_outputs import CausalLMOutputWithPast
11
+ from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm, Qwen2RotaryEmbedding, Qwen2DecoderLayer, Qwen2Model, Qwen2PreTrainedModel
12
+
13
+ from .configuration_xomni import XOmniConfig
14
+ from .modeling_siglip_tokenizer import create_anyres_preprocess, SiglipTokenizer
15
+ from .modeling_siglip_flux import FluxTransformer2DModelWithSigLIP, FluxPipelineWithSigLIP
16
+ from .modeling_vit import create_siglip_vit
17
+
18
+
19
+ class XOmniDecoderLayer(Qwen2DecoderLayer):
20
+ def __init__(self, config: XOmniConfig, layer_idx: int):
21
+ super().__init__(config, layer_idx)
22
+ self.layer_idx = layer_idx
23
+ self.is_lm_layer = config.num_mm_adap_layers <= layer_idx < config.num_hidden_layers - config.num_mm_head_layers  # middle layers are shared; the first adapter and last head layers update only multimodal positions
24
+
25
+ def forward(
26
+ self,
27
+ hidden_states: torch.Tensor,
28
+ **kwargs,
29
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
30
+ hidden_states, multimodal_mask = torch.split(hidden_states, hidden_states.shape[-1] // 2, dim=-1)  # the feature dim packs [hidden_states | multimodal_mask]
31
+ if self.is_lm_layer:
32
+ output_hidden_states, *others = super().forward(hidden_states, **kwargs)
33
+ output_hidden_states = torch.cat([output_hidden_states, multimodal_mask], dim=-1)
34
+ return output_hidden_states, *others
35
+
36
+ # mm_hidden_states = torch.where(multimodal_mask.bool(), hidden_states, torch.zeros_like(hidden_states))
37
+ output_hidden_states, *others = super().forward(hidden_states, **kwargs)
38
+ output_hidden_states = torch.where(multimodal_mask.bool(), output_hidden_states, hidden_states)  # text positions pass through unchanged; only multimodal positions take the update
39
+ output_hidden_states = torch.cat([output_hidden_states, multimodal_mask], dim=-1)
40
+ return output_hidden_states, *others
41
+
42
+
43
+ class XOmniModel(Qwen2Model, Qwen2PreTrainedModel):
44
+ model_type = "x-omni"
45
+ config_class = XOmniConfig
46
+
47
+ def __init__(self, config: XOmniConfig):
48
+ Qwen2PreTrainedModel.__init__(self, config)
49
+ self.padding_idx = -1
50
+ self.vocab_size = config.vocab_size
51
+
52
+ self.lm_embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
53
+ self.mm_embed_tokens = nn.Embedding(config.mm_vocab_size, config.hidden_size, self.padding_idx)
54
+
55
+ self.layers = nn.ModuleList(
56
+ [XOmniDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
57
+ )
58
+ self._attn_implementation = config._attn_implementation
59
+ self.lm_norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
60
+ self.mm_norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
61
+ self.rotary_emb = Qwen2RotaryEmbedding(config=config)
62
+
63
+ self.gradient_checkpointing = False
64
+ # Initialize weights and apply final processing
65
+ self.post_init()
66
+
67
+ def get_input_embeddings(self):
68
+ return self.lm_embed_tokens
69
+
70
+ def set_input_embeddings(self, value):
71
+ self.lm_embed_tokens = value
72
+
73
+ def embed_tokens(self, input_ids):
74
+ (B, L), C = input_ids.shape, self.config.hidden_size
75
+ multimodal_mask = input_ids >= self.config.vocab_size  # ids at or above the text vocab size are multimodal (image) tokens
76
+ lm_input_ids = input_ids[~multimodal_mask][None, :]
77
+ mm_input_ids = input_ids[multimodal_mask][None, :] - self.config.vocab_size
78
+ lm_embeds = self.lm_embed_tokens(lm_input_ids)
79
+ mm_embeds = self.mm_embed_tokens(mm_input_ids)
80
+
81
+ inputs_embeds = lm_embeds.new_empty((B, L, C))
82
+ multimodal_mask = multimodal_mask[:, :, None].expand_as(inputs_embeds)
83
+ inputs_embeds[~multimodal_mask] = lm_embeds.reshape(-1)
84
+ inputs_embeds[multimodal_mask] = mm_embeds.reshape(-1)
85
+
86
+ inputs_embeds = torch.cat([inputs_embeds, multimodal_mask.to(inputs_embeds.dtype)], dim=-1)  # pack the mask into the feature dim so each decoder layer can recover it
87
+ return inputs_embeds
88
+
89
+ def norm(self, hidden_states):
90
+ hidden_states, multimodal_mask = torch.split(hidden_states, hidden_states.shape[-1] // 2, dim=-1)
91
+ return torch.where(multimodal_mask.bool(), self.mm_norm(hidden_states), self.lm_norm(hidden_states))
92
+
93
+
94
+ class XOmniForCausalLM(Qwen2ForCausalLM):
95
+ model_type = "x-omni"
96
+ config_class = XOmniConfig
97
+
98
+ _keys_to_ignore_on_load_missing = [r'image_tokenizer\..*']
99
+
100
+ def __init__(self, config):
101
+ super().__init__(config)
102
+ self.model = XOmniModel(config)
103
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
104
+ self.mm_head = nn.Linear(config.hidden_size, config.mm_vocab_size, bias=False)
105
+
106
+ self.generation_mode = 'text'
107
+ # Initialize weights and apply final processing
108
+ self.post_init()
109
+
110
+ @property
111
+ def device(self):
112
+ return next(iter(self.parameters())).device
113
+
114
+ def init_vision(self, flux_pipe_path):
115
+ self.som_token = self.config.mm_special_tokens[0]
116
+ self.eom_token = self.config.mm_special_tokens[1]
117
+ self.img_token = self.config.mm_special_tokens[2]
118
+
119
+ self.vision_config = SimpleNamespace(**self.config.vision_config)
120
+ self.transform_config = SimpleNamespace(**self.vision_config.transform)
121
+ self.encoder_config = SimpleNamespace(**self.vision_config.encoder)
122
+ self.decoder_config = SimpleNamespace(**self.vision_config.decoder)
123
+
124
+ dtype_map = {'float32': torch.float32, 'float16': torch.float16, 'bfloat16': torch.bfloat16}
125
+ self.vision_dtype = dtype_map[self.vision_config.dtype]
126
+
127
+ self.image_transform = create_anyres_preprocess(**self.vision_config.transform)
128
+
129
+ self.encoder_config.siglip_path = os.path.join(self.name_or_path, self.encoder_config.siglip_path) if os.path.isdir(self.name_or_path) else hf_hub_download(repo_id=self.name_or_path, filename=self.encoder_config.siglip_path)
130
+ self.encoder_config.projector_path = os.path.join(self.name_or_path, self.encoder_config.projector_path) if os.path.isdir(self.name_or_path) else hf_hub_download(repo_id=self.name_or_path, filename=self.encoder_config.projector_path)
131
+
132
+ self.image_tokenizer = SiglipTokenizer(**vars(self.encoder_config))
133
+ self.image_tokenizer.to(self.device, self.vision_dtype)
134
+
135
+ self.decoder_pipe = FluxPipelineWithSigLIP.from_pretrained(
136
+ flux_pipe_path,
137
+ torch_dtype=self.vision_dtype,
138
+ )
139
+ self.decoder_pipe.transformer = FluxTransformer2DModelWithSigLIP.from_pretrained(
140
+ self.name_or_path,
141
+ siglip_channels=self.encoder_config.z_channels,
142
+ torch_dtype=self.vision_dtype,
143
+ subfolder=self.decoder_config.model_path,
144
+ )
145
+
146
+ self.decoder_pipe.set_progress_bar_config(disable=True)
147
+ self.decoder_pipe.to(self.device)
148
+
149
+ def set_generation_mode(self, mode):
150
+ assert mode in ('text', 'image'), f'Invalid generation mode: {mode}'
151
+ self.generation_mode = mode
152
+
153
+ def mmencode(self, tokenizer, texts=None, images=None, **kwargs):
154
+ texts = texts or []
155
+ images = images or []
156
+ doc = ''
157
+ while len(texts) > 0 or len(images) > 0:
158
+ if len(texts) > 0:
159
+ doc += texts.pop(0)
160
+ if len(images) > 0:
161
+ doc += self.tokenize_image(images.pop(0))
162
+ return tokenizer.encode(doc, **kwargs)
163
+
164
+ def mmdecode(self, tokenizer, token_ids, force_text=None, **kwargs):
165
+ force_text = force_text or []
166
+ if isinstance(token_ids, torch.Tensor):
167
+ if len(token_ids.shape) == 2:
168
+ assert token_ids.shape[0] == 1
169
+ token_ids = token_ids[0]
170
+ assert len(token_ids.shape) == 1
171
+ else:
172
+ if not isinstance(token_ids[0], int):
173
+ assert len(token_ids) == 1
174
+ token_ids = token_ids[0]
175
+ assert isinstance(token_ids[0], int)
176
+
177
+ doc = tokenizer.decode(token_ids, **kwargs)
178
+ doc = doc.replace(tokenizer.pad_token, '')
179
+ doc = doc.replace('<SEP>', '')
180
+ texts, images = [], []
181
+ text_image_chunks = doc.split(self.eom_token)
182
+ for chunk in text_image_chunks:
183
+ text, image_str = chunk.split(self.som_token) \
184
+ if self.som_token in chunk else (chunk, '')
185
+ texts.append(text)
186
+ if self.img_token in image_str:
187
+ image_meta, token_str = image_str.split(self.img_token)
188
+ H, W = tuple(map(int, image_meta.split(' ')))
189
+ token_ids = list(map(
190
+ lambda x: int(x.split('>')[0]),
191
+ token_str.split('<MM-Token-')[1:H*W+1],
192
+ ))
193
+ if len(force_text) > 0:
194
+ image = self.detokenize_image([force_text.pop(0)], images, token_ids, (H, W))
195
+ else:
196
+ image = self.detokenize_image(texts, images, token_ids, (H, W))
197
+ images.append(image)
198
+ return texts, images
199
+
200
+ @torch.no_grad()
201
+ def tokenize_image(self, image):
202
+ assert hasattr(self, 'image_tokenizer'), 'Please call "init_vision" first.'
203
+
204
+ image_str = self.som_token
205
+ image = self.image_transform(image)
206
+ assert image is not None, f'Unsupported image: aspect ratio exceeds {self.transform_config.max_aspect_ratio} or the short side is below {self.transform_config.min_short_size} pixels'
207
+
208
+ image = image[None, ...].to(self.device, self.vision_dtype)
209
+ tokens = self.image_tokenizer.encode(image)
210
+ B, H, W = tokens.shape
211
+ tokens = tokens.view(B, -1).cpu().tolist()[0]
212
+ token_str = ''.join(map(lambda x: '<MM-Token-{token_id}>'.format(token_id=x), tokens))
213
+ image_str = f'{self.som_token}{H} {W}{self.img_token}{token_str}{self.eom_token}'
214
+ return image_str
215
+
216
+ @torch.no_grad()
217
+ def detokenize_image(self, texts, images, token_ids, shape):
218
+ assert hasattr(self, 'image_tokenizer'), 'Please call "init_vision" first.'
219
+ assert len(texts) == 1 and len(images) == 0, 'Only one image per sample is supported.'
220
+ H, W = shape
221
+ tokens = torch.tensor(token_ids, device=self.device, dtype=torch.long)
222
+ latents = self.image_tokenizer.decode(tokens, (1, H, W, self.encoder_config.codebook_dim))
223
+ upscale_factor = self.decoder_config.upscale_factor
224
+ latents = latents.reshape(*latents.shape[:2], -1).transpose(1, 2).contiguous()
225
+ image = self.decoder_pipe(
226
+ latents,
227
+ [texts[0]],
228
+ negative_prompt=[''],
229
+ height=H * upscale_factor, width=W * upscale_factor,
230
+ num_inference_steps=self.decoder_config.num_inference_steps,
231
+ guidance_scale=1.0,
232
+ true_cfg_scale=self.decoder_config.cfg_scale,
233
+ true_cfg_scale_2=self.decoder_config.cfg_scale_2,
234
+ ).images[0]
235
+
236
+
237
+ return image
238
+
239
+ def forward(
240
+ self,
241
+ input_ids: torch.LongTensor = None,
242
+ attention_mask: Optional[torch.Tensor] = None,
243
+ position_ids: Optional[torch.LongTensor] = None,
244
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
245
+ inputs_embeds: Optional[torch.FloatTensor] = None,
246
+ labels: Optional[torch.LongTensor] = None,
247
+ use_cache: Optional[bool] = None,
248
+ output_attentions: Optional[bool] = None,
249
+ output_hidden_states: Optional[bool] = None,
250
+ return_dict: Optional[bool] = None,
251
+ cache_position: Optional[torch.LongTensor] = None,
252
+ num_logits_to_keep: int = 0,
253
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
254
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
255
+ output_hidden_states = (
256
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
257
+ )
258
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
259
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
260
+ outputs = self.model(
261
+ input_ids=input_ids,
262
+ attention_mask=attention_mask,
263
+ position_ids=position_ids,
264
+ past_key_values=past_key_values,
265
+ inputs_embeds=inputs_embeds,
266
+ use_cache=use_cache,
267
+ output_attentions=output_attentions,
268
+ output_hidden_states=output_hidden_states,
269
+ return_dict=return_dict,
270
+ cache_position=cache_position,
271
+ )
272
+
273
+ hidden_states = outputs[0]
274
+ hidden_states = hidden_states[:, -num_logits_to_keep:, :]
275
+ logits = hidden_states.new_full(
276
+ (*hidden_states.shape[:-1], self.config.vocab_size + self.config.mm_vocab_size),
277
+ torch.finfo(hidden_states.dtype).min
278
+ )
279
+ if self.generation_mode == 'text':  # only the active vocabulary gets real scores; the other stays at -inf
280
+ logits[:, :, :self.config.vocab_size] = self.lm_head(hidden_states)
281
+ else:
282
+ logits[:, :, self.config.vocab_size:self.config.vocab_size + self.config.image_vocab_size] = self.mm_head(hidden_states)[:, :, :self.config.image_vocab_size]
283
+
284
+ logits = logits.float()
285
+
286
+ loss = None
287
+ if labels is not None:
288
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
289
+ logits = logits.float()
290
+ # Shift so that tokens < n predict n
291
+ shift_logits = logits[..., :-1, :].contiguous()
292
+ shift_labels = labels[..., 1:].contiguous()
293
+ # Flatten the tokens
294
+ loss_fct = nn.CrossEntropyLoss()
295
+ shift_logits = shift_logits.view(-1, shift_logits.size(-1))  # last dim spans both the text and multimodal vocabularies, not just vocab_size
296
+ shift_labels = shift_labels.view(-1)
297
+ # Enable model parallelism
298
+ shift_labels = shift_labels.to(shift_logits.device)
299
+ loss = loss_fct(shift_logits, shift_labels)
300
+
301
+ if not return_dict:
302
+ output = (logits,) + outputs[1:]
303
+ return (loss,) + output if loss is not None else output
304
+
305
+ return CausalLMOutputWithPast(
306
+ loss=loss,
307
+ logits=logits,
308
+ past_key_values=outputs.past_key_values,
309
+ hidden_states=outputs.hidden_states,
310
+ attentions=outputs.attentions,
311
+ )
312
+
313
+
314
+ AutoModel.register(XOmniConfig, XOmniModel)
315
+ AutoModelForCausalLM.register(XOmniConfig, XOmniForCausalLM)
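
A hypothetical end-to-end sketch, not shipped in this commit, of how the classes registered above are meant to be used: load the checkpoint with trust_remote_code, call init_vision to attach the SigLIP image tokenizer and the FLUX decoder, then round-trip an image through mmencode and mmdecode. MODEL_ID, FLUX_PIPE, and the local image file are placeholders; a full FLUX pipeline path in particular is an assumption, since this repo only ships the transformer weights under the diffusers/ subfolder.

# Hypothetical sketch with placeholder paths.
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "<this-repo-or-local-path>"      # placeholder
FLUX_PIPE = "<flux-pipeline-repo-or-path>"  # placeholder (assumed full FLUX pipeline)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16, trust_remote_code=True
).to("cuda")
model.init_vision(FLUX_PIPE)  # attaches the SigLIP image tokenizer and the FLUX decoder pipeline

# interleave a caption and an image into one discrete token sequence ...
caption = "A photo of a cat."
ids = model.mmencode(tokenizer, texts=[caption + " "], images=[Image.open("cat.jpg")])

# ... and decode it back: text chunks come back as strings, image chunks are re-rendered
# through the FLUX decoder (force_text supplies the caption used for conditioning)
texts, images = model.mmdecode(tokenizer, ids, force_text=[caption])
images[0].save("reconstruction.png")

# for free-form image generation, call model.set_generation_mode('image') before
# model.generate so only image-token logits stay active (the prompt template is not shown here)
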
special_tokens_map.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f0f6cd9880d5d8ffdd4761e68e770a51845e4021966bc93d1558c67d83e55fe8
3
+ size 14625209
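
A hypothetical sanity check, assuming the tokenizer files in this repo: the modeling code above relies on <SOM>, <EOM>, <IMAGE> and the <MM-Token-*> vocabulary resolving to single added-token ids, which can be verified as follows. The repo path is a placeholder.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("<this-repo-or-local-path>")  # placeholder path
for token in ["<SOM>", "<EOM>", "<IMAGE>", "<MM-Token-0>", "<MM-Token-16383>"]:
    ids = tok.encode(token, add_special_tokens=False)
    print(token, ids)  # each entry is expected to resolve to a single added-token id
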
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
vit/siglip_vq.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:225aa2d41094ee6c83d2d184d2b6cc9cdadb4f72c1a94501f916cffd42d3b567
3
+ size 249612138
vit/vit_g.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:250be9ba1a52a5b1365cfd79276cbf022df63577eed5d9cc234f8603d06ef626
3
+ size 2376032176
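
The two LFS pointers above (vit/siglip_vq.pt and vit/vit_g.pth) are the vector-quantization projector and the raw SigLIP-2 encoder weights that init_vision loads through the relative paths in config.json. Below is a minimal sketch of that resolution logic, mirroring the branch in modeling_xomni.init_vision; REPO_ID and the helper name are placeholders.

import os
from huggingface_hub import hf_hub_download

REPO_ID = "<this-repo-or-local-path>"  # placeholder

def resolve_checkpoint(name_or_path: str, relative_path: str) -> str:
    # local snapshot: join the relative path; hub repo id: download the single file
    if os.path.isdir(name_or_path):
        return os.path.join(name_or_path, relative_path)
    return hf_hub_download(repo_id=name_or_path, filename=relative_path)

siglip_path = resolve_checkpoint(REPO_ID, "vit/vit_g.pth")        # encoder backbone
projector_path = resolve_checkpoint(REPO_ID, "vit/siglip_vq.pt")  # VQ projector
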
vocab.json ADDED
The diff for this file is too large to render. See raw diff