EthanReid committed
Commit fc5dc52 · 1 Parent(s): 4f59175
Files changed (7)
  1. README.md +12 -1
  2. config.py +1 -0
  3. layers.py +66 -2
  4. model.safetensors +2 -2
  5. moondream.py +15 -9
  6. text.py +43 -87
  7. weights.py +177 -220
README.md CHANGED
@@ -9,6 +9,10 @@ Moondream is a small vision language model designed to run efficiently everywhere.
 
 This repository contains the latest (**2025-04-14**) release of Moondream, as well as [historical releases](https://huggingface.co/vikhyatk/moondream2/blob/main/versions.txt). The model is updated frequently, so we recommend specifying a revision as shown below if you're using it in a production application.
 
+To use **quantized int4**, make sure to install the requirements:
+```
+pip install -r https://depot.moondream.ai/transformers/requirements.txt
+```
 
 ### Usage
 
@@ -16,9 +20,11 @@ This repository contains the latest (**2025-04-14**) release of Moondream, as well
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from PIL import Image
 
+# To run in float16 instead, use revision="2025-04-14".
 model = AutoModelForCausalLM.from_pretrained(
     "vikhyatk/moondream2",
-    revision="2025-04-14",
+    revision="int4_2025-04-14",
+    # revision="2025-04-14",
     trust_remote_code=True,
     # Uncomment to run on GPU.
     # device_map={"": "cuda"}
@@ -50,6 +56,11 @@ print(f"Found {len(points)} person(s)")
 ```
 
 ### Changelog
+**int4-2025-04-15** ([full release notes](https://moondream.ai/blog/moondream-2025-04-14-release))
+1. Moondream uses far less memory (down from 4.12 GB to 2.47 GB)
+2. Small devices get a big speed-up (44.54 to 67.84 tok/sec on an RTX 4050 Mobile)
+3. Improved spatial understanding (RealWorldQA up from 58.3 to 60.13)
+
 
 **2025-04-15** ([full release notes](https://moondream.ai/blog/moondream-2025-04-14-release))
 
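For reference, a minimal end-to-end sketch of loading the new int4 revision and asking a question, assuming the same `query` API shown in the existing usage section (the image path is a placeholder):

```python
from transformers import AutoModelForCausalLM
from PIL import Image

# Assumes the int4 requirements above are installed and a CUDA GPU is available.
model = AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream2",
    revision="int4_2025-04-14",  # use "2025-04-14" for the float16 release
    trust_remote_code=True,
    device_map={"": "cuda"},     # the quantized bitblas kernels target GPU
)

image = Image.open("example.jpg")  # placeholder path
print(model.query(image, "What is in this image?")["answer"])
```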
 
config.py CHANGED
@@ -12,6 +12,7 @@ class TextConfig:
     n_heads: int = 32
     n_kv_heads: int = 32
     prefix_attn: int = 730
+    group_size: int = 128
 
 
 @dataclass(frozen=True)
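The new `group_size` field controls how many input channels share one quantization scale/zero-point in the int4 linears. As a rough illustration of the bookkeeping this implies (standard group-wise quantization in general, not code from this repo; the `dim` value is an assumption about the text model's hidden size):

```python
dim = 2048        # assumed TextConfig.dim for the text model
group_size = 128  # new TextConfig.group_size

# Each output row of a quantized weight matrix is split into dim // group_size
# groups, and every group stores its own scale and zero-point.
groups_per_row = dim // group_size
print(groups_per_row)  # 16
```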
layers.py CHANGED
@@ -1,7 +1,10 @@
+import bitblas
+import torch
+import torch.nn as nn
+
 from dataclasses import dataclass
 from typing import Literal
-
-import torch
+from bitblas.cache import OperatorCache
 from torch.nn import functional as F
 
 
@@ -15,6 +18,66 @@ class LinearWeights:
     bias: torch.Tensor
 
 
+class Linear(nn.Module):
+    """
+    Linear layer with support for bitblas quantization.
+    If dtype is torch.int8, it uses bitblas for quantization.
+    Otherwise, it uses a standard nn.Linear layer.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        dtype: torch.dtype = None,
+        operator_cache: OperatorCache = None,
+        cache_dir: str = None,
+        group_size: int = 128,
+    ):
+        super().__init__()
+
+        if dtype == torch.int8:
+            self.linear = bitblas.Linear(
+                in_features=in_features,
+                out_features=out_features,
+                bias=bias,
+                with_zeros=True,
+                zeros_mode="original",
+                with_scaling=True,
+                A_dtype="float16",
+                W_dtype="uint4",
+                accum_dtype="float16",
+                out_dtype="float16",
+                fast_decoding=True,
+                enable_tuning=True,
+                operator_cache=operator_cache,
+                database_path=cache_dir,
+                group_size=group_size,
+            )
+        else:
+            self.linear = nn.Linear(
+                in_features=in_features,
+                out_features=out_features,
+                bias=bias,
+                dtype=torch.float16,
+            )
+
+    def forward(self, x):
+        return self.linear(x)
+
+    @property
+    def weight(self) -> torch.Tensor:
+        try:
+            return self.linear.weight
+        except AttributeError:
+            return self.linear.qweight
+
+    @property
+    def bias(self) -> torch.Tensor:
+        return self.linear.bias
+
+
 def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
     return F.linear(x, w.weight, w.bias)
 
@@ -37,6 +100,7 @@ class MLPWeights:
 
 
 def mlp(x: torch.Tensor, w: MLPWeights) -> torch.Tensor:
+
     x = w.fc1(x)
     x = gelu_approx(x)
     x = w.fc2(x)
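A minimal sketch of how the new `Linear` wrapper is intended to be used (the wrapper is the class added above; exercising the int8 path requires a CUDA GPU with bitblas installed and may trigger kernel tuning):

```python
import torch
from bitblas.cache import OperatorCache
# Linear here refers to the wrapper defined in layers.py above.

# Float16 path: behaves like a plain nn.Linear.
proj_fp16 = Linear(2048, 2048, dtype=torch.float16)

# Quantized path: torch.int8 acts as a marker that selects the bitblas
# uint4 kernel with per-group scales and zero-points.
proj_int4 = Linear(
    2048,
    2048,
    dtype=torch.int8,
    operator_cache=OperatorCache(),
    cache_dir="./cache",
    group_size=128,
)

x = torch.randn(1, 2048, dtype=torch.float16)
y = proj_fp16(x)  # both variants are called the same way
```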
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:96dce588e4a319fde7af3c70fbf27e726f4850e22522d0fdc4b165d5e6003ad5
-size 3854538376
+oid sha256:5f86cdffeecef5dfab629bf93dbbf3b4ec480d5eaa4ab11b18714b92d76c4303
+size 2080366176
moondream.py CHANGED
@@ -66,12 +66,16 @@ class MoondreamModel(nn.Module):
     def __init__(self, config: MoondreamConfig, dtype=torch.float16, setup_caches=True):
         super().__init__()
         self.config = config
+        self.dtype = dtype
+        self.setup_caches_flag = setup_caches
 
         self.tokenizer = Tokenizer.from_pretrained(
             "vikhyatk/moondream2", revision="2025-01-09"
         )
+
         self.vision = build_vision_model(config.vision, dtype)
-        self.text = build_text_model(config.text, dtype)
+
+        self.text = None
 
         # Region Model
         self.region = nn.ModuleDict(
@@ -125,11 +129,11 @@ class MoondreamModel(nn.Module):
         attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1
         self.register_buffer("attn_mask", attn_mask, persistent=False)
 
-        # Initialize KV caches.
-        if setup_caches:
-            self._setup_caches()
-
     def _setup_caches(self):
+        """Setup KV caches for the text model"""
+        if self.text is None:
+            return  # Can't set up caches without text model
+
         c = self.config.text
         for b in self.text.blocks:
             b.kv_cache = KVCache(
@@ -163,11 +167,11 @@ class MoondreamModel(nn.Module):
 
     def compile(self):
         # TODO: vision_projection is not being compiled
-        self._vis_enc = torch.compile(self._vis_enc, fullgraph=True)
-        self._prefill = torch.compile(self._prefill, fullgraph=True)
-        self._decode_one_tok = torch.compile(
-            self._decode_one_tok, fullgraph=True, mode="reduce-overhead"
+        self._vis_enc = torch.compile(
+            self._vis_enc, fullgraph=False, mode="reduce-overhead"
         )
+        self._prefill = torch.compile(self._prefill)
+        self._decode_one_tok = torch.compile(self._decode_one_tok)
 
     def _run_vision_encoder(self, image: Image.Image) -> torch.Tensor:
         all_crops, tiling = prepare_crops(image, self.config.vision, device=self.device)
@@ -200,6 +204,7 @@ class MoondreamModel(nn.Module):
 
         # Run through text model in addition to the vision encoder, to minimize
         # re-computation if multiple queries are performed on this image.
+
         with torch.inference_mode():
             img_emb = self._run_vision_encoder(image)
             bos_emb = text_encoder(
@@ -235,6 +240,7 @@ class MoondreamModel(nn.Module):
     def _prefill_prompt(
         self, prompt_tokens: torch.Tensor, pos: int, temperature: float, top_p: float
     ):
+
        with torch.inference_mode():
            prompt_emb = text_encoder(prompt_tokens, self.text)
            torch._dynamo.mark_dynamic(prompt_emb, 1)
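The net effect of the moondream.py changes: `__init__` no longer builds the text model or its KV caches; both are deferred until weights are loaded, because only the checkpoint reveals whether the int4 or float16 text stack is needed. A rough sketch of the resulting flow, assuming this commit's modules are importable as a local package (the import paths are illustrative):

```python
from config import MoondreamConfig
from moondream import MoondreamModel
from weights import load_weights_into_model

model = MoondreamModel(MoondreamConfig(), setup_caches=True)
assert model.text is None  # the text model is now built lazily

# Loading inspects the checkpoint keys, builds the (possibly int4) text model
# via build_text_model, then sets up KV caches since setup_caches_flag is True.
load_weights_into_model("model.safetensors", model)
assert model.text is not None
```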
text.py CHANGED
@@ -2,8 +2,9 @@ import torch
 import torch.nn as nn
 
 from torch.nn import functional as F
+from bitblas.cache import OperatorCache
 
-from .layers import layer_norm, mlp
+from .layers import layer_norm, mlp, Linear
 from .rope import apply_rotary_emb, precompute_freqs_cis
 from .config import TextConfig
 
@@ -26,6 +27,7 @@ def attn(
     head_dim = d_model // n_heads
 
     qkv_out = w.qkv(x)  # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)
+
     q_dim = n_heads * head_dim
     kv_dim = n_kv_heads * head_dim
 
@@ -55,71 +57,6 @@ def attn(
     return out
 
 
-def _attn(
-    x: torch.Tensor,
-    w: torch.Tensor,
-    freqs_cis: torch.Tensor,
-    attn_mask: torch.Tensor,
-    n_heads: int,
-    n_kv_heads: int,
-):
-    bsz, q_len, d_model = x.shape
-    head_dim = d_model // n_heads
-    pos = 0
-
-    qkv_out = w.qkv(x)  # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)
-    q_dim = n_heads * head_dim
-    kv_dim = n_kv_heads * head_dim
-
-    q = qkv_out[..., :q_dim].view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
-    k = (
-        qkv_out[..., q_dim : q_dim + kv_dim]
-        .view(bsz, q_len, n_kv_heads, head_dim)
-        .transpose(1, 2)
-    )
-    v = (
-        qkv_out[..., q_dim + kv_dim :]
-        .view(bsz, q_len, n_kv_heads, head_dim)
-        .transpose(1, 2)
-    )
-
-    position_ids = torch.arange(pos, pos + q_len, dtype=torch.long)
-    q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
-    k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads)
-    out = F.scaled_dot_product_attention(
-        q, k, v, attn_mask=attn_mask, enable_gqa=n_heads != n_kv_heads
-    )
-    out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
-    out = w.proj(out)
-    return out
-
-
-def _produce_hidden(inputs_embeds: torch.Tensor, w: nn.Module, config: TextConfig):
-    hidden_BTC = inputs_embeds
-
-    bsz, q_len, d_model = inputs_embeds.shape
-    attn_mask = torch.zeros(q_len, q_len)
-    attn_mask[:730, :730] = 1
-    for i in range(730, q_len):
-        attn_mask[i, : i + 1] = 1
-    attn_mask = attn_mask.to(dtype=torch.bool)
-
-    for i, block in enumerate(w.blocks):
-        l_in = layer_norm(hidden_BTC, block.ln)
-        l_attn = _attn(
-            x=l_in,
-            w=block.attn,
-            freqs_cis=w.freqs_cis,
-            attn_mask=attn_mask,
-            n_heads=config.n_heads,
-            n_kv_heads=config.n_kv_heads,
-        )
-        l_mlp = mlp(l_in, block.mlp)
-        hidden_BTC = hidden_BTC + l_attn + l_mlp
-
-    return hidden_BTC
-
-
 def text_decoder(
     x: torch.Tensor,
     w: nn.Module,
@@ -139,6 +76,7 @@ def text_decoder(
             n_kv_heads=config.n_kv_heads,
             position_ids=position_ids,
         )
+
         l_mlp = mlp(l_in, block.mlp)
         x = x + l_attn + l_mlp
 
@@ -152,38 +90,54 @@ def lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
     return logits
 
 
-def _lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
-    hidden_BTC = layer_norm(hidden_BTC, w.post_ln)
-    logits = w.lm_head(hidden_BTC)
-    return logits
-
-
-def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
+def build_text_model(
+    config: TextConfig,
+    linear_dtype: torch.dtype = torch.float16,
+    layernorm_dtype: torch.dtype = torch.float16,
+) -> nn.Module:
+    # note: layernorm_dtype is used for the layernorms, lm_head, and wte, not just the layernorms
+    print(
+        "Initializing quantized backend. This only has to run once, but may take a few minutes."
+    )
     qkv_dim = int(config.dim * (1 + 2 * config.n_kv_heads / config.n_heads))
 
+    operator_cache = None
+    cache_dir = None
+    group_size = None
+    if linear_dtype == torch.int8:
+
+        operator_cache = OperatorCache()
+        cache_dir = "./cache"
+        group_size = config.group_size
+
+    def create_linear(in_features, out_features, dtype=linear_dtype):
+        # factory function for creating Linear layers so we don't have to pass everything repeatedly
+        return Linear(
+            in_features=in_features,
+            out_features=out_features,
+            dtype=dtype,
+            operator_cache=operator_cache,
+            cache_dir=cache_dir,
+            group_size=group_size,
+        )
+
     text = nn.ModuleDict(
         {
             "blocks": nn.ModuleList(
                 [
                     nn.ModuleDict(
                         {
-                            "ln": nn.LayerNorm(config.dim, dtype=dtype),
+                            "ln": nn.LayerNorm(config.dim, dtype=layernorm_dtype),
                             "attn": nn.ModuleDict(
                                 {
-                                    "qkv": nn.Linear(config.dim, qkv_dim, dtype=dtype),
-                                    "proj": nn.Linear(
-                                        config.dim, config.dim, dtype=dtype
-                                    ),
+                                    "qkv": create_linear(config.dim, qkv_dim),
+                                    "proj": create_linear(config.dim, config.dim),
                                 }
                             ),
                             "mlp": nn.ModuleDict(
                                 {
-                                    "fc1": nn.Linear(
-                                        config.dim, config.ff_dim, dtype=dtype
-                                    ),
-                                    "fc2": nn.Linear(
-                                        config.ff_dim, config.dim, dtype=dtype
-                                    ),
+                                    "fc1": create_linear(config.dim, config.ff_dim),
+                                    "fc2": create_linear(config.ff_dim, config.dim),
                                 }
                             ),
                         }
@@ -191,11 +145,13 @@ def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
                     for _ in range(config.n_layers)
                 ]
             ),
-            "post_ln": nn.LayerNorm(config.dim, dtype=dtype),
-            "lm_head": nn.Linear(config.dim, config.vocab_size, dtype=dtype),
+            "post_ln": nn.LayerNorm(config.dim, dtype=layernorm_dtype),
+            "lm_head": nn.Linear(config.dim, config.vocab_size, dtype=layernorm_dtype),
         }
     )
-    text.wte = nn.Parameter(torch.empty(config.vocab_size, config.dim, dtype=dtype))
+    text.wte = nn.Parameter(
+        torch.empty(config.vocab_size, config.dim, dtype=layernorm_dtype)
+    )
    text.register_buffer(
        "freqs_cis",
        precompute_freqs_cis(config.dim // (2 * config.n_heads), config.max_context),
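How the two dtype knobs are meant to be used, mirroring the calls made from weights.py below: `linear_dtype=torch.int8` selects the bitblas int4 linears, while the layernorms, `lm_head`, and the embedding table follow `layernorm_dtype`. Note that `build_text_model` is passed the `TextConfig` class itself and reads its class-level defaults. A small sketch, assuming this commit's `text.py` and `config.py` are importable:

```python
import torch
# build_text_model and TextConfig as defined in this commit's text.py / config.py.

# Quantized text stack: every attention/MLP linear uses the bitblas int4 kernel.
text_int4 = build_text_model(
    TextConfig,
    linear_dtype=torch.int8,
    layernorm_dtype=torch.float16,
)

# Full-precision text stack: plain float16 nn.Linear everywhere.
text_fp16 = build_text_model(
    TextConfig,
    linear_dtype=torch.float16,
    layernorm_dtype=torch.float16,
)
```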
weights.py CHANGED
@@ -1,61 +1,25 @@
 import safetensors
 import torch
 import torch.nn as nn
+import re
 
 from contextlib import contextmanager
-from dataclasses import dataclass
 from typing import Callable, List
 
-from .layers import AttentionWeights, LayerNormWeights, LinearWeights, MLPWeights
+from .text import build_text_model
+from .config import TextConfig
 
 
-@dataclass
-class VisionBlock:
-    ln1: LayerNormWeights
-    attn: AttentionWeights
-    ln2: LayerNormWeights
-    mlp: MLPWeights
-
-
-@dataclass
-class VisionModel:
-    patch_emb: LinearWeights
-    pos_emb: torch.Tensor
-    blocks: List[VisionBlock]
-    post_ln: LayerNormWeights
-    proj_mlp: MLPWeights
-
-
-@dataclass
-class TextBlock:
-    ln: LayerNormWeights
-    attn: AttentionWeights
-    mlp: MLPWeights
-
-
-@dataclass
-class TextModel:
-    wte: torch.Tensor
-    blocks: List[TextBlock]
-    post_ln: LayerNormWeights
-    lm_head: LinearWeights
-
-
-@dataclass
-class RegionModel:
-    coord_features: torch.Tensor
-    coord_encoder: LinearWeights
-    coord_decoder: MLPWeights
-    size_features: torch.Tensor
-    size_encoder: LinearWeights
-    size_decoder: MLPWeights
-
-
-@dataclass
-class MoondreamModel:
-    vision: VisionModel
-    text: TextModel
-    region: RegionModel
+# Our custom Linear has a module named linear, so we add "linear" to the key name
+def add_linear_to_key(k: str) -> str:
+    k = k.replace("model.", "")
+    if k.startswith("text.") and ".linear." not in k:
+        k = re.sub(
+            r"(attn\.(?:qkv|proj)|mlp\.fc[12])\.(weight|bias)$",
+            r"\1.linear.\2",
+            k,
+        )
+    return k
 
 
 @contextmanager
@@ -79,199 +43,192 @@ def safetensors_open(safetensors_file: str):
     yield get_tensor
 
 
-def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) -> None:
+def _load_weights(
+    get_tensor: Callable[[str], torch.Tensor],
+    model: nn.Module,
+    is_quantized: bool = False,
+) -> None:
     """Internal function to load weights using a tensor getter function."""
     model = model.to(dtype=torch.float16)
 
-    # Vision Model
-    model.vision["patch_emb"].weight.data.copy_(
-        get_tensor("vision_encoder.encoder.model.visual.patch_embed.linear.weight")
-    )
-    model.vision["patch_emb"].bias.data.copy_(
-        get_tensor("vision_encoder.encoder.model.visual.patch_embed.linear.bias")
-    )
-    model.vision.pos_emb.data.copy_(
-        get_tensor("vision_encoder.encoder.model.visual.pos_embed")
-    )
+    vision = model.vision
+    region = model.region
+
+    weight_map = {
+        "vision_encoder.encoder.model.visual.patch_embed.linear.weight": vision[
+            "patch_emb"
+        ].weight,
+        "vision_encoder.encoder.model.visual.patch_embed.linear.bias": vision[
+            "patch_emb"
+        ].bias,
+        "vision_encoder.encoder.model.visual.pos_embed": vision.pos_emb,
+        "vision_encoder.encoder.model.visual.norm.weight": vision["post_ln"].weight,
+        "vision_encoder.encoder.model.visual.norm.bias": vision["post_ln"].bias,
+        "vision_encoder.projection.mlp.fc1.weight": vision["proj_mlp"]["fc1"].weight,
+        "vision_encoder.projection.mlp.fc1.bias": vision["proj_mlp"]["fc1"].bias,
+        "vision_encoder.projection.mlp.fc2.weight": vision["proj_mlp"]["fc2"].weight,
+        "vision_encoder.projection.mlp.fc2.bias": vision["proj_mlp"]["fc2"].bias,
+        "text_model.transformer.embd.wte.weight": model.text.wte,
+        "text_model.lm_head.ln.weight": model.text["post_ln"].weight,
+        "text_model.lm_head.ln.bias": model.text["post_ln"].bias,
+        "text_model.lm_head.linear.weight": model.text["lm_head"].weight,
+        "text_model.lm_head.linear.bias": model.text["lm_head"].bias,
+        "region_model.coordinate_encoder.weight": region["coord_encoder"].weight,
+        "region_model.coordinate_encoder.bias": region["coord_encoder"].bias,
+        "region_model.coordinate_decoder.fc1.weight": region["coord_decoder"][
+            "fc1"
+        ].weight,
+        "region_model.coordinate_decoder.fc1.bias": region["coord_decoder"]["fc1"].bias,
+        "region_model.coordinate_decoder.fc2.weight": region["coord_decoder"][
+            "fc2"
+        ].weight,
+        "region_model.coordinate_decoder.fc2.bias": region["coord_decoder"]["fc2"].bias,
+        "region_model.size_encoder.weight": region["size_encoder"].weight,
+        "region_model.size_encoder.bias": region["size_encoder"].bias,
+        "region_model.size_decoder.fc1.weight": region["size_decoder"]["fc1"].weight,
+        "region_model.size_decoder.fc1.bias": region["size_decoder"]["fc1"].bias,
+        "region_model.size_decoder.fc2.weight": region["size_decoder"]["fc2"].weight,
+        "region_model.size_decoder.fc2.bias": region["size_decoder"]["fc2"].bias,
+    }
 
     for i in range(len(model.vision["blocks"])):
         prefix = f"vision_encoder.encoder.model.visual.blocks.{i}"
-
-        # Layer norms
-        model.vision["blocks"][i]["ln1"].weight.data.copy_(
-            get_tensor(f"{prefix}.norm1.weight")
-        )
-        model.vision["blocks"][i]["ln1"].bias.data.copy_(
-            get_tensor(f"{prefix}.norm1.bias")
-        )
-        model.vision["blocks"][i]["ln2"].weight.data.copy_(
-            get_tensor(f"{prefix}.norm2.weight")
-        )
-        model.vision["blocks"][i]["ln2"].bias.data.copy_(
-            get_tensor(f"{prefix}.norm2.bias")
-        )
-
-        # Attention
-        model.vision["blocks"][i]["attn"]["qkv"].weight.data.copy_(
-            get_tensor(f"{prefix}.attn.qkv.weight")
-        )
-        model.vision["blocks"][i]["attn"]["qkv"].bias.data.copy_(
-            get_tensor(f"{prefix}.attn.qkv.bias")
-        )
-        model.vision["blocks"][i]["attn"]["proj"].weight.data.copy_(
-            get_tensor(f"{prefix}.attn.proj.weight")
-        )
-        model.vision["blocks"][i]["attn"]["proj"].bias.data.copy_(
-            get_tensor(f"{prefix}.attn.proj.bias")
-        )
-
-        # MLP
-        model.vision["blocks"][i]["mlp"]["fc1"].weight.data.copy_(
-            get_tensor(f"{prefix}.mlp.fc1.weight")
-        )
-        model.vision["blocks"][i]["mlp"]["fc1"].bias.data.copy_(
-            get_tensor(f"{prefix}.mlp.fc1.bias")
-        )
-        model.vision["blocks"][i]["mlp"]["fc2"].weight.data.copy_(
-            get_tensor(f"{prefix}.mlp.fc2.weight")
-        )
-        model.vision["blocks"][i]["mlp"]["fc2"].bias.data.copy_(
-            get_tensor(f"{prefix}.mlp.fc2.bias")
-        )
-
-    model.vision["post_ln"].weight.data.copy_(
-        get_tensor("vision_encoder.encoder.model.visual.norm.weight")
-    )
-    model.vision["post_ln"].bias.data.copy_(
-        get_tensor("vision_encoder.encoder.model.visual.norm.bias")
-    )
-
-    model.vision["proj_mlp"]["fc1"].weight.data.copy_(
-        get_tensor("vision_encoder.projection.mlp.fc1.weight")
-    )
-    model.vision["proj_mlp"]["fc1"].bias.data.copy_(
-        get_tensor("vision_encoder.projection.mlp.fc1.bias")
-    )
-    model.vision["proj_mlp"]["fc2"].weight.data.copy_(
-        get_tensor("vision_encoder.projection.mlp.fc2.weight")
-    )
-    model.vision["proj_mlp"]["fc2"].bias.data.copy_(
-        get_tensor("vision_encoder.projection.mlp.fc2.bias")
-    )
-
-    # Text Model
-    model.text.wte.data.copy_(get_tensor("text_model.transformer.embd.wte.weight"))
-
-    for i in range(len(model.text["blocks"])):
-        prefix = f"text_model.transformer.h.{i}"
-
-        # Layer norm
-        model.text["blocks"][i]["ln"].weight.data.copy_(
-            get_tensor(f"{prefix}.ln.weight")
-        )
-        model.text["blocks"][i]["ln"].bias.data.copy_(get_tensor(f"{prefix}.ln.bias"))
-
-        # Attention
-        model.text["blocks"][i]["attn"]["qkv"].weight.data.copy_(
-            get_tensor(f"{prefix}.mixer.Wqkv.weight")
-        )
-        model.text["blocks"][i]["attn"]["qkv"].bias.data.copy_(
-            get_tensor(f"{prefix}.mixer.Wqkv.bias")
-        )
-        model.text["blocks"][i]["attn"]["proj"].weight.data.copy_(
-            get_tensor(f"{prefix}.mixer.out_proj.weight")
-        )
-        model.text["blocks"][i]["attn"]["proj"].bias.data.copy_(
-            get_tensor(f"{prefix}.mixer.out_proj.bias")
-        )
-
-        # MLP
-        model.text["blocks"][i]["mlp"]["fc1"].weight.data.copy_(
-            get_tensor(f"{prefix}.mlp.fc1.weight")
-        )
-        model.text["blocks"][i]["mlp"]["fc1"].bias.data.copy_(
-            get_tensor(f"{prefix}.mlp.fc1.bias")
-        )
-        model.text["blocks"][i]["mlp"]["fc2"].weight.data.copy_(
-            get_tensor(f"{prefix}.mlp.fc2.weight")
-        )
-        model.text["blocks"][i]["mlp"]["fc2"].bias.data.copy_(
-            get_tensor(f"{prefix}.mlp.fc2.bias")
-        )
-
-    model.text["post_ln"].weight.data.copy_(get_tensor("text_model.lm_head.ln.weight"))
-    model.text["post_ln"].bias.data.copy_(get_tensor("text_model.lm_head.ln.bias"))
-
-    model.text["lm_head"].weight.data.copy_(
-        get_tensor("text_model.lm_head.linear.weight")
-    )
-    model.text["lm_head"].bias.data.copy_(get_tensor("text_model.lm_head.linear.bias"))
-
-    # Region Model
-    model.region.coord_features.data.copy_(
-        get_tensor("region_model.coordinate_features.weight").T
-    )
-    model.region["coord_encoder"].weight.data.copy_(
-        get_tensor("region_model.coordinate_encoder.weight")
-    )
-    model.region["coord_encoder"].bias.data.copy_(
-        get_tensor("region_model.coordinate_encoder.bias")
-    )
-
-    model.region["coord_decoder"]["fc1"].weight.data.copy_(
-        get_tensor("region_model.coordinate_decoder.fc1.weight")
-    )
-    model.region["coord_decoder"]["fc1"].bias.data.copy_(
-        get_tensor("region_model.coordinate_decoder.fc1.bias")
-    )
-    model.region["coord_decoder"]["fc2"].weight.data.copy_(
-        get_tensor("region_model.coordinate_decoder.fc2.weight")
-    )
-    model.region["coord_decoder"]["fc2"].bias.data.copy_(
-        get_tensor("region_model.coordinate_decoder.fc2.bias")
-    )
-
-    model.region.size_features.data.copy_(
-        get_tensor("region_model.size_features.weight").T
-    )
-    model.region["size_encoder"].weight.data.copy_(
-        get_tensor("region_model.size_encoder.weight")
-    )
-    model.region["size_encoder"].bias.data.copy_(
-        get_tensor("region_model.size_encoder.bias")
-    )
-
-    model.region["size_decoder"]["fc1"].weight.data.copy_(
-        get_tensor("region_model.size_decoder.fc1.weight")
-    )
-    model.region["size_decoder"]["fc1"].bias.data.copy_(
-        get_tensor("region_model.size_decoder.fc1.bias")
-    )
-    model.region["size_decoder"]["fc2"].weight.data.copy_(
-        get_tensor("region_model.size_decoder.fc2.weight")
-    )
-    model.region["size_decoder"]["fc2"].bias.data.copy_(
-        get_tensor("region_model.size_decoder.fc2.bias")
-    )
-
-
-def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None:
-    """Load weights from a safetensors file into a MoondreamModel instance."""
-    with safetensors_open(weights_file) as get_tensor:
-        # Wrap the get_tensor function to handle key normalization
-        name_map = {k.replace("._orig_mod", ""): k for k in get_tensor.keys()}
-        _load_weights(lambda x: get_tensor(name_map[x]).to(dtype=torch.float16), model)
-
-
-def load_weights_from_pt(weights_file: str, model: nn.Module) -> None:
-    """Load weights from a PyTorch file into a MoondreamModel instance."""
-    device = str(torch.empty(0).device)
-    tensors = torch.load(weights_file, map_location=device, weights_only=True)
-    tensors = {
-        k.replace("._orig_mod", ""): v.to(dtype=torch.float16)
-        for k, v in tensors.items()
-    }
-    _load_weights(lambda x: tensors[x], model)
+        blk = model.vision["blocks"][i]
+        weight_map.update(
+            {
+                f"{prefix}.norm1.weight": blk["ln1"].weight,
+                f"{prefix}.norm1.bias": blk["ln1"].bias,
+                f"{prefix}.norm2.weight": blk["ln2"].weight,
+                f"{prefix}.norm2.bias": blk["ln2"].bias,
+                f"{prefix}.attn.qkv.weight": blk["attn"]["qkv"].weight,
+                f"{prefix}.attn.qkv.bias": blk["attn"]["qkv"].bias,
+                f"{prefix}.attn.proj.weight": blk["attn"]["proj"].weight,
+                f"{prefix}.attn.proj.bias": blk["attn"]["proj"].bias,
+                f"{prefix}.mlp.fc1.weight": blk["mlp"]["fc1"].weight,
+                f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias,
+                f"{prefix}.mlp.fc2.weight": blk["mlp"]["fc2"].weight,
+                f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias,
+            }
+        )
+
+    if not is_quantized:
+        for i in range(len(model.text["blocks"])):
+            prefix = f"text_model.transformer.h.{i}"
+            blk = model.text["blocks"][i]
+            weight_map.update(
+                {
+                    f"{prefix}.ln.weight": blk["ln"].weight,
+                    f"{prefix}.ln.bias": blk["ln"].bias,
+                    f"{prefix}.mixer.Wqkv.weight": blk["attn"]["qkv"].weight,
+                    f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias,
+                    f"{prefix}.mixer.out_proj.weight": blk["attn"]["proj"].weight,
+                    f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias,
+                    f"{prefix}.mlp.fc1.weight": blk["mlp"]["fc1"].weight,
+                    f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias,
+                    f"{prefix}.mlp.fc2.weight": blk["mlp"]["fc2"].weight,
+                    f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias,
+                }
+            )
+    else:  # Special quantized path: bitblas stores the quantized weights under .qweight
+        for i in range(len(model.text["blocks"])):
+            prefix = f"text_model.transformer.h.{i}"
+            blk = model.text["blocks"][i]
+            weight_map.update(
+                {
+                    f"{prefix}.ln.qweight": blk["ln"].weight,
+                    f"{prefix}.ln.bias": blk["ln"].bias,
+                    f"{prefix}.mixer.Wqkv.qweight": blk["attn"]["qkv"].weight,
+                    f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias,
+                    f"{prefix}.mixer.out_proj.qweight": blk["attn"]["proj"].weight,
+                    f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias,
+                    f"{prefix}.mlp.fc1.qweight": blk["mlp"]["fc1"].weight,
+                    f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias,
+                    f"{prefix}.mlp.fc2.qweight": blk["mlp"]["fc2"].weight,
+                    f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias,
+                }
+            )
+
+    for key, tensor in weight_map.items():
+        tensor.data.copy_(get_tensor(key))
+
+    region.coord_features.data.copy_(
+        get_tensor("region_model.coordinate_features.weight").T
+    )
+    region.size_features.data.copy_(get_tensor("region_model.size_features.weight").T)
+
+
+def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None:
+    """Load weights from a safetensors file into a MoondreamModel instance."""
+    with safetensors_open(weights_file) as get_tensor:
+        all_keys = get_tensor.keys()
+
+        is_quantized = any(
+            ".qweight" in key or "_quantized" in key or "quant." in key
+            for key in all_keys
+        )
+
+        if "text_model.transformer.h.0.ln.weight" in all_keys:
+            layernorm_dtype = get_tensor("text_model.transformer.h.0.ln.weight").dtype
+        else:
+            layernorm_dtype = torch.float16
+
+        linear_dtype = torch.int8 if is_quantized else torch.float16
+
+        model.text = build_text_model(
+            TextConfig, linear_dtype=linear_dtype, layernorm_dtype=layernorm_dtype
+        )
+        if model.setup_caches_flag:
+            model._setup_caches()
+
+        if (
+            "vision.blocks.0.attn.proj.bias" in all_keys
+            or "model.vision.blocks.0.attn.proj.bias" in all_keys
+        ):
+            with safetensors_open(weights_file) as get_tensor:
+                tensors = {add_linear_to_key(k): get_tensor(k) for k in all_keys}
+            model.load_state_dict(tensors, strict=False)
+        else:
+            # Wrap the get_tensor function to handle key normalization
+            name_map = {k.replace("._orig_mod", ""): k for k in all_keys}
+            _load_weights(
+                lambda x: get_tensor(name_map[x]).to(dtype=torch.float16),
+                model,
+                is_quantized,
+            )
+
+
+def load_weights_from_pt(weights_file: str, model: nn.Module) -> None:
+    """Load weights from a PyTorch file into a MoondreamModel instance."""
+    tensors = torch.load(weights_file, map_location="cpu", weights_only=True)
+    all_keys = tensors.keys()
+    is_quantized = any(
+        ".qweight" in key or "_quantized" in key or "quant." in key for key in all_keys
+    )
+
+    if "text.blocks.0.ln.weight" in all_keys:
+        layernorm_dtype = tensors["text.blocks.0.ln.weight"].dtype
+    else:
+        layernorm_dtype = torch.float16
+
+    linear_dtype = torch.int8 if is_quantized else torch.float16
+    model.text = build_text_model(
+        TextConfig, linear_dtype=linear_dtype, layernorm_dtype=layernorm_dtype
+    )
+    if model.setup_caches_flag:
+        model._setup_caches()
+
+    if (
+        "vision.blocks.0.attn.proj.bias" in all_keys
+        or "model.vision.blocks.0.attn.proj.bias" in all_keys
+    ):
+        tensors = {add_linear_to_key(k): v for k, v in tensors.items()}
+        model.load_state_dict(tensors, strict=False)
+    else:
+        tensors = {
+            k.replace("._orig_mod", ""): v.to(dtype=torch.float16)
+            for k, v in tensors.items()
+        }
+        _load_weights(lambda x: tensors[x], model, is_quantized)
 
 
 def load_weights_into_model(weights_file: str, model: nn.Module) -> None:
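To make the `add_linear_to_key` rewrite concrete, here is what it does to a few representative state-dict keys (a quick standalone check using a copy of the helper exactly as committed):

```python
import re

def add_linear_to_key(k: str) -> str:
    # Copy of the helper from this commit, reproduced for illustration.
    k = k.replace("model.", "")
    if k.startswith("text.") and ".linear." not in k:
        k = re.sub(
            r"(attn\.(?:qkv|proj)|mlp\.fc[12])\.(weight|bias)$",
            r"\1.linear.\2",
            k,
        )
    return k

print(add_linear_to_key("model.text.blocks.0.attn.qkv.weight"))
# text.blocks.0.attn.qkv.linear.weight   -> points inside the new Linear wrapper
print(add_linear_to_key("model.text.blocks.0.ln.weight"))
# text.blocks.0.ln.weight                -> layernorm keys are left untouched
print(add_linear_to_key("model.vision.blocks.0.attn.proj.weight"))
# vision.blocks.0.attn.proj.weight       -> non-text keys are not rewritten
```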