bowenbaoamd committed
Commit da5a039 · 1 Parent(s): d46fba6

Delete files modeling_grok1.py with huggingface_hub

Files changed (1)
  1. modeling_grok1.py +0 -946
modeling_grok1.py DELETED
@@ -1,946 +0,0 @@
1
- from typing import List, Optional, Tuple, Union
2
-
3
- import torch
4
- import torch.nn as nn
5
- import torch.nn.functional as F
6
- from transformers.modeling_utils import PreTrainedModel
7
- from transformers.utils import logging
8
-
9
- try:
10
- from transformers.modeling_attn_mask_utils import \
11
- _prepare_4d_causal_attention_mask
12
-
13
- HAS_MASK_UTILS = True
14
- except ImportError:
15
- HAS_MASK_UTILS = False
16
-
17
- from .configuration_grok1 import Grok1Config
18
- from .modeling_grok1_outputs import (MoeCausalLMOutputWithPast,
19
- MoeModelOutputWithPast)
20
-
21
- logger = logging.get_logger(__name__)
22
-
23
-
24
- # copied from https://github.com/huggingface/transformers/blob/v4.36.1/src/transformers/models/mixtral/modeling_mixtral.py
25
- def load_balancing_loss_func(
26
- gate_logits: torch.Tensor, num_experts: int = None, top_k=2
27
- ) -> float:
28
- r"""
29
- Computes the auxiliary load balancing loss as in Switch Transformer - implemented in PyTorch.
30
-
31
- See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
32
- function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
33
- experts is too unbalanced.
34
-
35
- Args:
36
- gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]]):
37
- Logits from the `gate`, should be a tuple of tensors. Shape: [batch_size, sequence_length, num_experts].
38
- num_experts (`int`, *optional*):
39
- Number of experts
40
-
41
- Returns:
42
- The auxiliary loss.
43
- """
44
- if gate_logits is None:
45
- return 0
46
-
47
- if isinstance(gate_logits, tuple):
48
- # concatenate the gate logits from all layers along dim 0
49
- compute_device = gate_logits[0].device
50
- gate_logits = torch.cat(
51
- [gate.to(compute_device) for gate in gate_logits], dim=0
52
- )
53
-
54
- routing_weights, selected_experts = torch.topk(gate_logits, top_k, dim=-1)
55
- routing_weights = routing_weights.softmax(dim=-1)
56
-
57
- # cast the expert indices to int64, otherwise one-hot encoding will fail
58
- if selected_experts.dtype != torch.int64:
59
- selected_experts = selected_experts.to(torch.int64)
60
-
61
- if len(selected_experts.shape) == 2:
62
- selected_experts = selected_experts.unsqueeze(2)
63
-
64
- expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
65
-
66
- # For a given token, determine if it was routed to a given expert.
67
- expert_mask = torch.max(expert_mask, axis=-2).values
68
-
69
- # cast to float32 otherwise mean will fail
70
- expert_mask = expert_mask.to(torch.float32)
71
- tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2)
72
-
73
- router_prob_per_group_and_expert = torch.mean(routing_weights, axis=-1)
74
- return torch.mean(
75
- tokens_per_group_and_expert * router_prob_per_group_and_expert.unsqueeze(-1)
76
- ) * (num_experts**2)
77
-
78
-
79
- # Copied from transformers.models.llama.modeling_llama.repeat_kv
80
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
81
- """
82
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
83
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
84
- """
85
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
86
- if n_rep == 1:
87
- return hidden_states
88
- hidden_states = hidden_states[:, :, None, :, :].expand(
89
- batch, num_key_value_heads, n_rep, slen, head_dim
90
- )
91
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
92
-
93
-
94
- class RMSNorm(nn.Module):
95
- def __init__(
96
- self,
97
- hidden_size: int,
98
- eps: float = 1e-5,
99
- create_scale: bool = True,
100
- ) -> None:
101
- super().__init__()
102
- self.variance_epsilon = eps
103
- if create_scale:
104
- self.weight = nn.Parameter(torch.zeros(hidden_size))
105
- else:
106
- self.weight = 1.0
107
-
108
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
109
- input_dtype = hidden_states.dtype
110
- hidden_states = hidden_states.to(torch.float32)
111
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
112
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
113
- hidden_states = self.weight * hidden_states
114
- return hidden_states.to(input_dtype)
115
-
116
-
117
- class RotaryEmbedding(nn.Module):
118
- def __init__(
119
- self, dim: int, max_position_embeddings: int = 2048, base: int = 10000
120
- ) -> None:
121
- super().__init__()
122
- assert dim % 2 == 0
123
- self.dim = dim
124
- self.max_position_embeddings = max_position_embeddings
125
- self.base = base
126
- inv_freq = 1.0 / (
127
- self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)
128
- )
129
- self.register_buffer("inv_freq", inv_freq, persistent=False)
130
-
131
- self._set_cos_sin_cache(
132
- seq_len=max_position_embeddings,
133
- device=self.inv_freq.device,
134
- dtype=torch.get_default_dtype(),
135
- )
136
-
137
- def _set_cos_sin_cache(self, seq_len, device, dtype):
138
- self.max_seq_len_cached = seq_len
139
- t = torch.arange(
140
- self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
141
- )
142
-
143
- freqs = torch.outer(t, self.inv_freq)
144
- # Different from the paper, but uses a different permutation to obtain the same calculation
145
- emb = torch.cat((freqs, freqs), dim=-1)
146
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
147
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
148
-
149
- def forward(self, x, seq_len=None):
150
- # x: [bs, num_attention_heads, seq_len, head_size]
151
- if seq_len > self.max_seq_len_cached:
152
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
153
-
154
- return (
155
- self.cos_cached[:seq_len].to(dtype=x.dtype),
156
- self.sin_cached[:seq_len].to(dtype=x.dtype),
157
- )
158
-
159
-
160
- # Copied from transformers.models.llama.modeling_llama.rotate_half
161
- def rotate_half(x):
162
- """Rotates half the hidden dims of the input."""
163
- x1 = x[..., : x.shape[-1] // 2]
164
- x2 = x[..., x.shape[-1] // 2 :]
165
- return torch.cat((-x2, x1), dim=-1)
166
-
167
-
168
- # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
169
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
170
- """Applies Rotary Position Embedding to the query and key tensors.
171
-
172
- Args:
173
- q (`torch.Tensor`): The query tensor.
174
- k (`torch.Tensor`): The key tensor.
175
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
176
- sin (`torch.Tensor`): The sine part of the rotary embedding.
177
- position_ids (`torch.Tensor`):
178
- The position indices of the tokens corresponding to the query and key tensors. For example, this can be
179
- used to pass offset position ids when working with a KV-cache.
180
- unsqueeze_dim (`int`, *optional*, defaults to 1):
181
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
182
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
183
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
184
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
185
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
186
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
187
- Returns:
188
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
189
- """
190
- cos = cos[position_ids].unsqueeze(unsqueeze_dim)
191
- sin = sin[position_ids].unsqueeze(unsqueeze_dim)
192
- q_embed = (q * cos) + (rotate_half(q) * sin)
193
- k_embed = (k * cos) + (rotate_half(k) * sin)
194
- return q_embed, k_embed
195
-
196
-
197
- class MultiHeadAttention(nn.Module):
198
- def __init__(
199
- self,
200
- hidden_size: int,
201
- num_heads: int,
202
- num_key_value_heads: Optional[int] = None,
203
- max_position_embeddings: int = 2048,
204
- attn_output_multiplier: float = 1.0,
205
- max_attn_val: float = 30.0,
206
- ):
207
- super().__init__()
208
- self.hidden_size = hidden_size
209
- self.num_heads = num_heads
210
- self.head_dim = hidden_size // num_heads
211
- if num_key_value_heads is None:
212
- num_key_value_heads = num_heads
213
- self.num_key_value_heads = num_key_value_heads
214
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
215
- self.attn_output_multiplier = attn_output_multiplier
216
- self.max_attn_val = max_attn_val
217
-
218
- if (self.head_dim * self.num_heads) != self.hidden_size:
219
- raise ValueError(
220
- f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
221
- f" and `num_heads`: {self.num_heads})."
222
- )
223
-
224
- self.q_proj = nn.Linear(hidden_size, self.num_heads * self.head_dim, bias=False)
225
- self.k_proj = nn.Linear(
226
- hidden_size, self.num_key_value_heads * self.head_dim, bias=False
227
- )
228
- self.v_proj = nn.Linear(
229
- hidden_size, self.num_key_value_heads * self.head_dim, bias=False
230
- )
231
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, hidden_size, bias=False)
232
-
233
- self.rotary_emb = RotaryEmbedding(
234
- self.head_dim,
235
- max_position_embeddings=max_position_embeddings,
236
- )
237
-
238
- def forward(
239
- self,
240
- hidden_states: torch.Tensor,
241
- attention_mask: Optional[torch.Tensor] = None,
242
- position_ids: Optional[torch.LongTensor] = None,
243
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
244
- output_attentions: bool = False,
245
- use_cache: bool = False,
246
- **kwargs,
247
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
248
- bsz, q_len, _ = hidden_states.size()
249
-
250
- query_states = self.q_proj(hidden_states)
251
- key_states = self.k_proj(hidden_states)
252
- value_states = self.v_proj(hidden_states)
253
-
254
- query_states = query_states.view(
255
- bsz, q_len, self.num_heads, self.head_dim
256
- ).transpose(1, 2)
257
- key_states = key_states.view(
258
- bsz, q_len, self.num_key_value_heads, self.head_dim
259
- ).transpose(1, 2)
260
- value_states = value_states.view(
261
- bsz, q_len, self.num_key_value_heads, self.head_dim
262
- ).transpose(1, 2)
263
-
264
- kv_seq_len = key_states.shape[-2]
265
- if past_key_value is not None:
266
- kv_seq_len += past_key_value[0].shape[-2]
267
-
268
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
269
- query_states, key_states = apply_rotary_pos_emb(
270
- query_states, key_states, cos, sin, position_ids
271
- )
272
-
273
- if past_key_value is not None:
274
- # reuse k, v, self_attention
275
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
276
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
277
-
278
- past_key_value = (key_states, value_states) if use_cache else None
279
-
280
- # repeat k/v heads if n_kv_heads < n_heads
281
- key_states = repeat_kv(key_states, self.num_key_value_groups)
282
- value_states = repeat_kv(value_states, self.num_key_value_groups)
283
-
284
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)).to(
285
- torch.float
286
- )
287
- attn_weights = attn_weights * self.attn_output_multiplier
288
- attn_weights = self.max_attn_val * F.tanh(attn_weights / self.max_attn_val)
289
-
290
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
291
- raise ValueError(
292
- f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
293
- f" {attn_weights.size()}"
294
- )
295
-
296
- if attention_mask is not None:
297
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
298
- raise ValueError(
299
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
300
- )
301
-
302
- attn_weights = attn_weights + attention_mask
303
-
304
- attn_weights = F.softmax(attn_weights, dim=-1).to(query_states.dtype)
305
- attn_output = torch.matmul(attn_weights, value_states)
306
-
307
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
308
- raise ValueError(
309
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
310
- f" {attn_output.size()}"
311
- )
312
-
313
- attn_output = attn_output.transpose(1, 2).contiguous()
314
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
315
-
316
- attn_output = self.o_proj(attn_output)
317
-
318
- if not output_attentions:
319
- attn_weights = None
320
-
321
- return attn_output, attn_weights, past_key_value
322
-
323
-
324
- class MoeMLP(nn.Module):
325
- def __init__(
326
- self,
327
- hidden_dim: int,
328
- ffn_dim: int,
329
- ) -> None:
330
- super().__init__()
331
- self.w3 = nn.Linear(hidden_dim, ffn_dim, bias=False)
332
- self.w2 = nn.Linear(ffn_dim, hidden_dim, bias=False)
333
- self.w1 = nn.Linear(hidden_dim, ffn_dim, bias=False)
334
- self.act_fn = nn.GELU()
335
-
336
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
337
- current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(
338
- hidden_states
339
- )
340
- current_hidden_states = self.w2(current_hidden_states)
341
- return current_hidden_states
342
-
343
-
344
- class MoeBlock(nn.Module):
345
- def __init__(
346
- self,
347
- hidden_dim: int,
348
- ffn_dim: int,
349
- num_experts: int,
350
- top_k: int,
351
- ) -> None:
352
- super().__init__()
353
- self.num_experts = num_experts
354
- self.top_k = top_k
355
- self.gate = nn.Linear(hidden_dim, num_experts, bias=False)
356
- self.experts = nn.ModuleList(
357
- [MoeMLP(hidden_dim, ffn_dim) for _ in range(num_experts)]
358
- )
359
-
360
- def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]:
361
- batch_size, sequence_length, hidden_dim = hidden_states.shape
362
- hidden_states = hidden_states.view(-1, hidden_dim)
363
- # router_logits: (batch * sequence_length, n_experts)
364
- router_logits = self.gate(hidden_states)
365
-
366
- routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
367
- routing_weights, selected_experts = torch.topk(
368
- routing_weights, self.top_k, dim=-1
369
- )
370
- # we cast back to the input dtype
371
- routing_weights = routing_weights.to(hidden_states.dtype)
372
-
373
- final_hidden_states = torch.zeros(
374
- (batch_size * sequence_length, hidden_dim),
375
- dtype=hidden_states.dtype,
376
- device=hidden_states.device,
377
- )
378
- # One hot encode the selected experts to create an expert mask
379
- # this will be used to easily index which expert is going to be solicited
380
- expert_mask = torch.nn.functional.one_hot(
381
- selected_experts, num_classes=self.num_experts
382
- ).permute(2, 1, 0)
383
-
384
- # Loop over all available experts in the model and perform the computation on each expert
385
- for expert_idx in range(self.num_experts):
386
- expert_layer = self.experts[expert_idx]
387
- idx, top_x = torch.where(expert_mask[expert_idx])
388
-
389
- if top_x.shape[0] == 0:
390
- continue
391
-
392
- # in torch it is faster to index using lists than torch tensors
393
- top_x_list = top_x.tolist()
394
- idx_list = idx.tolist()
395
-
396
- # Index the correct hidden states and compute the expert hidden state for
397
- # the current expert. We need to make sure to multiply the output hidden
398
- # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
399
- current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
400
- current_hidden_states = (
401
- expert_layer(current_state)
402
- * routing_weights[top_x_list, idx_list, None]
403
- )
404
-
405
- # However `index_add_` only supports torch tensors for indexing so we'll use
406
- # the `top_x` tensor here.
407
- final_hidden_states.index_add_(
408
- 0, top_x, current_hidden_states.to(hidden_states.dtype)
409
- )
410
- final_hidden_states = final_hidden_states.reshape(
411
- batch_size, sequence_length, hidden_dim
412
- )
413
- return final_hidden_states, router_logits
414
-
415
-
416
- class DecoderLayer(nn.Module):
417
- def __init__(
418
- self,
419
- hidden_size: int,
420
- intermediate_size: int,
421
- num_heads: int,
422
- num_key_value_heads: int,
423
- num_experts: int,
424
- top_k: int,
425
- max_position_embeddings: int = 2048,
426
- attn_output_multiplier: float = 1.0,
427
- max_attn_val: float = 30.0,
428
- rms_norm_eps: float = 1e-5,
429
- ) -> None:
430
- super().__init__()
431
- self.self_attn = MultiHeadAttention(
432
- hidden_size,
433
- num_heads,
434
- num_key_value_heads,
435
- max_position_embeddings=max_position_embeddings,
436
- attn_output_multiplier=attn_output_multiplier,
437
- max_attn_val=max_attn_val,
438
- )
439
- self.block_sparse_moe = MoeBlock(hidden_size, intermediate_size, num_experts, top_k)
440
- self.pre_attn_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
441
- self.post_attn_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
442
- self.pre_moe_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
443
- self.post_moe_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
444
-
445
- def forward(
446
- self,
447
- hidden_states: torch.Tensor,
448
- attention_mask: Optional[torch.Tensor] = None,
449
- position_ids: Optional[torch.LongTensor] = None,
450
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
451
- output_attentions: Optional[bool] = False,
452
- output_router_logits: Optional[bool] = False,
453
- use_cache: Optional[bool] = False,
454
- **kwargs,
455
- ) -> Tuple[
456
- torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
457
- ]:
458
- residual = hidden_states
459
- hidden_states = self.pre_attn_norm(hidden_states)
460
- hidden_states, attention_weights, present_key_value = self.self_attn(
461
- hidden_states,
462
- attention_mask=attention_mask,
463
- position_ids=position_ids,
464
- past_key_value=past_key_value,
465
- output_attentions=output_attentions,
466
- use_cache=use_cache,
467
- )
468
- hidden_states = self.post_attn_norm(hidden_states)
469
- hidden_states = residual + hidden_states
470
-
471
- residual = hidden_states
472
- hidden_states = self.pre_moe_norm(hidden_states)
473
- hidden_states, router_logits = self.block_sparse_moe(hidden_states)
474
- hidden_states = self.post_moe_norm(hidden_states)
475
- hidden_states = residual + hidden_states
476
-
477
- outputs = (hidden_states,)
478
- if output_attentions:
479
- outputs += (attention_weights,)
480
- if use_cache:
481
- outputs += (present_key_value,)
482
- if output_router_logits:
483
- outputs += (router_logits,)
484
- return outputs
485
-
486
-
487
- class Grok1PretrainedModel(PreTrainedModel):
488
- config_class = Grok1Config
489
- base_model_prefix = "model"
490
- supports_gradient_checkpointing = True
491
- _no_split_modules = ["DecoderLayer"]
492
- _skip_keys_device_placement = "past_key_values"
493
- _supports_flash_attn_2 = False
494
- _supports_cache_class = False
495
-
496
- def _init_weights(self, module) -> None:
497
- if isinstance(module, nn.Linear):
498
- module.weight.data.zero_()
499
- if module.bias is not None:
500
- module.bias.data.zero_()
501
- elif isinstance(module, nn.Embedding):
502
- module.weight.data.zero_()
503
-
504
-
505
- # Copied from transformers.models.bart.modeling_bart._make_causal_mask
506
- def _make_causal_mask(
507
- input_ids_shape: torch.Size,
508
- dtype: torch.dtype,
509
- device: torch.device,
510
- past_key_values_length: int = 0,
511
- ):
512
- """
513
- Make causal mask used for bi-directional self-attention.
514
- """
515
- bsz, tgt_len = input_ids_shape
516
- mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
517
- mask_cond = torch.arange(mask.size(-1), device=device)
518
- mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
519
- mask = mask.to(dtype)
520
-
521
- if past_key_values_length > 0:
522
- mask = torch.cat(
523
- [
524
- torch.zeros(
525
- tgt_len, past_key_values_length, dtype=dtype, device=device
526
- ),
527
- mask,
528
- ],
529
- dim=-1,
530
- )
531
- return mask[None, None, :, :].expand(
532
- bsz, 1, tgt_len, tgt_len + past_key_values_length
533
- )
534
-
535
-
536
- # Copied from transformers.models.bart.modeling_bart._expand_mask
537
- def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
538
- """
539
- Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
540
- """
541
- bsz, src_len = mask.size()
542
- tgt_len = tgt_len if tgt_len is not None else src_len
543
-
544
- expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
545
-
546
- inverted_mask = 1.0 - expanded_mask
547
-
548
- return inverted_mask.masked_fill(
549
- inverted_mask.to(torch.bool), torch.finfo(dtype).min
550
- )
551
-
552
-
553
- class Grok1Model(Grok1PretrainedModel):
554
- def __init__(self, config: Grok1Config, **kwargs) -> None:
555
- super().__init__(config)
556
- self.padding_idx = config.pad_token_id
557
- self.vocab_size = config.vocab_size
558
- self.embedding_multiplier_scale = config.embedding_multiplier_scale
559
-
560
- self.embed_tokens = nn.Embedding(
561
- config.vocab_size, config.hidden_size, self.padding_idx
562
- )
563
- self.layers = nn.ModuleList(
564
- [
565
- DecoderLayer(
566
- hidden_size=config.hidden_size,
567
- intermediate_size=config.intermediate_size,
568
- num_heads=config.num_attention_heads,
569
- num_key_value_heads=config.num_key_value_heads,
570
- num_experts=config.num_experts,
571
- top_k=config.num_experts_per_tok,
572
- max_position_embeddings=config.max_position_embeddings,
573
- attn_output_multiplier=config.attn_output_multiplier,
574
- max_attn_val=config.max_attn_value,
575
- rms_norm_eps=config.rms_norm_eps,
576
- )
577
- for layer_idx in range(config.num_hidden_layers)
578
- ]
579
- )
580
- self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
581
- self.gradient_checkpointing = False
582
- self.post_init()
583
-
584
- def get_input_embeddings(self):
585
- return self.embed_tokens
586
-
587
- def set_input_embeddings(self, value):
588
- self.embed_tokens = value
589
-
590
- # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
591
- def _prepare_decoder_attention_mask(
592
- self, attention_mask, input_shape, inputs_embeds, past_key_values_length
593
- ):
594
- # create causal mask
595
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
596
- combined_attention_mask = None
597
- if input_shape[-1] > 1:
598
- combined_attention_mask = _make_causal_mask(
599
- input_shape,
600
- inputs_embeds.dtype,
601
- device=inputs_embeds.device,
602
- past_key_values_length=past_key_values_length,
603
- )
604
-
605
- if attention_mask is not None:
606
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
607
- expanded_attn_mask = _expand_mask(
608
- attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
609
- ).to(inputs_embeds.device)
610
- combined_attention_mask = (
611
- expanded_attn_mask
612
- if combined_attention_mask is None
613
- else expanded_attn_mask + combined_attention_mask
614
- )
615
-
616
- return combined_attention_mask
617
-
618
- def forward(
619
- self,
620
- input_ids: torch.LongTensor = None,
621
- attention_mask: Optional[torch.Tensor] = None,
622
- position_ids: Optional[torch.LongTensor] = None,
623
- past_key_values: Optional[List[torch.FloatTensor]] = None,
624
- inputs_embeds: Optional[torch.FloatTensor] = None,
625
- use_cache: Optional[bool] = None,
626
- output_attentions: Optional[bool] = None,
627
- output_hidden_states: Optional[bool] = None,
628
- output_router_logits: Optional[bool] = None,
629
- return_dict: Optional[bool] = None,
630
- ) -> Union[Tuple, MoeModelOutputWithPast]:
631
- output_attentions = (
632
- output_attentions
633
- if output_attentions is not None
634
- else self.config.output_attentions
635
- )
636
- output_hidden_states = (
637
- output_hidden_states
638
- if output_hidden_states is not None
639
- else self.config.output_hidden_states
640
- )
641
- use_cache = use_cache if use_cache is not None else self.config.use_cache
642
-
643
- return_dict = (
644
- return_dict if return_dict is not None else self.config.use_return_dict
645
- )
646
-
647
- # retrieve input_ids and inputs_embeds
648
- if input_ids is not None and inputs_embeds is not None:
649
- raise ValueError(
650
- "You cannot specify both input_ids and inputs_embeds at the same time"
651
- )
652
- elif input_ids is not None:
653
- batch_size, seq_length = input_ids.shape[:2]
654
- elif inputs_embeds is not None:
655
- batch_size, seq_length = inputs_embeds.shape[:2]
656
- else:
657
- raise ValueError("You have to specify either input_ids or inputs_embeds")
658
-
659
- seq_length_with_past = seq_length
660
- past_key_values_length = 0
661
- if past_key_values is not None:
662
- past_key_values_length = past_key_values[0][0].shape[2]
663
- seq_length_with_past = seq_length_with_past + past_key_values_length
664
-
665
- if position_ids is None:
666
- device = input_ids.device if input_ids is not None else inputs_embeds.device
667
- position_ids = torch.arange(
668
- past_key_values_length,
669
- seq_length + past_key_values_length,
670
- dtype=torch.long,
671
- device=device,
672
- )
673
- position_ids = position_ids.unsqueeze(0)
674
-
675
- if inputs_embeds is None:
676
- inputs_embeds = self.embed_tokens(input_ids)
677
- inputs_embeds = inputs_embeds * self.embedding_multiplier_scale
678
-
679
- if HAS_MASK_UTILS:
680
- # 4d mask is passed through the layers
681
- attention_mask = _prepare_4d_causal_attention_mask(
682
- attention_mask,
683
- (batch_size, seq_length),
684
- inputs_embeds,
685
- past_key_values_length,
686
- )
687
- else:
688
- if attention_mask is None:
689
- attention_mask = torch.ones(
690
- (batch_size, seq_length_with_past),
691
- dtype=torch.bool,
692
- device=inputs_embeds.device,
693
- )
694
- attention_mask = self._prepare_decoder_attention_mask(
695
- attention_mask,
696
- (batch_size, seq_length),
697
- inputs_embeds,
698
- past_key_values_length,
699
- )
700
-
701
- # embed positions
702
- hidden_states = inputs_embeds
703
-
704
- if self.gradient_checkpointing and self.training:
705
- if use_cache:
706
- logger.warning_once(
707
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
708
- )
709
- use_cache = False
710
-
711
- # decoder layers
712
- all_hidden_states = () if output_hidden_states else None
713
- all_self_attns = () if output_attentions else None
714
- all_router_logits = () if output_router_logits else None
715
- next_decoder_cache = () if use_cache else None
716
-
717
- for idx, decoder_layer in enumerate(self.layers):
718
- if output_hidden_states:
719
- all_hidden_states += (hidden_states,)
720
-
721
- past_key_value = (
722
- past_key_values[idx] if past_key_values is not None else None
723
- )
724
-
725
- if self.gradient_checkpointing and self.training:
726
-
727
- def create_custom_forward(module):
728
- def custom_forward(*inputs):
729
- # None for past_key_value
730
- return module(*inputs, past_key_value, output_attentions)
731
-
732
- return custom_forward
733
-
734
- layer_outputs = torch.utils.checkpoint.checkpoint(
735
- create_custom_forward(decoder_layer),
736
- hidden_states,
737
- attention_mask,
738
- position_ids,
739
- )
740
- else:
741
- layer_outputs = decoder_layer(
742
- hidden_states,
743
- attention_mask=attention_mask,
744
- position_ids=position_ids,
745
- past_key_value=past_key_value,
746
- output_attentions=output_attentions,
747
- use_cache=use_cache,
748
- )
749
-
750
- hidden_states = layer_outputs[0]
751
-
752
- if use_cache:
753
- next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
754
-
755
- if output_attentions:
756
- all_self_attns += (layer_outputs[1],)
757
-
758
- if output_router_logits:
759
- all_router_logits += (layer_outputs[-1],)
760
-
761
- hidden_states = self.norm(hidden_states)
762
-
763
- # add hidden states from the last decoder layer
764
- if output_hidden_states:
765
- all_hidden_states += (hidden_states,)
766
- next_cache = next_decoder_cache if use_cache else None
767
-
768
- if not return_dict:
769
- return tuple(
770
- v
771
- for v in [
772
- hidden_states,
773
- next_cache,
774
- all_hidden_states,
775
- all_self_attns,
776
- all_router_logits,
777
- ]
778
- if v is not None
779
- )
780
- return MoeModelOutputWithPast(
781
- last_hidden_state=hidden_states,
782
- past_key_values=next_cache,
783
- hidden_states=all_hidden_states,
784
- attentions=all_self_attns,
785
- router_logits=all_router_logits,
786
- )
787
-
788
-
789
- class Grok1ModelForCausalLM(Grok1PretrainedModel):
790
- _tied_weights_keys = ["lm_head.weight"]
791
-
792
- def __init__(self, config: Grok1Config, **kwargs):
793
- super().__init__(config)
794
- self.model = Grok1Model(config)
795
- self.vocab_size = config.vocab_size
796
- self.output_multiplier_scale = config.output_multiplier_scale
797
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
798
- self.router_aux_loss_coef = config.router_aux_loss_coef
799
- self.num_experts = config.num_experts
800
- self.num_experts_per_tok = config.num_experts_per_tok
801
- self.post_init()
802
-
803
- def get_input_embeddings(self):
804
- return self.model.embed_tokens
805
-
806
- def set_input_embeddings(self, value):
807
- self.model.embed_tokens = value
808
-
809
- def get_output_embeddings(self):
810
- return self.lm_head
811
-
812
- def set_output_embeddings(self, new_embeddings):
813
- self.lm_head = new_embeddings
814
-
815
- def set_decoder(self, decoder):
816
- self.model = decoder
817
-
818
- def get_decoder(self):
819
- return self.model
820
-
821
- def forward(
822
- self,
823
- input_ids: torch.LongTensor = None,
824
- attention_mask: Optional[torch.Tensor] = None,
825
- position_ids: Optional[torch.LongTensor] = None,
826
- past_key_values: Optional[List[torch.FloatTensor]] = None,
827
- inputs_embeds: Optional[torch.FloatTensor] = None,
828
- labels: Optional[torch.LongTensor] = None,
829
- use_cache: Optional[bool] = None,
830
- output_attentions: Optional[bool] = None,
831
- output_hidden_states: Optional[bool] = None,
832
- output_router_logits: Optional[bool] = None,
833
- return_dict: Optional[bool] = None,
834
- ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
835
- output_attentions = (
836
- output_attentions
837
- if output_attentions is not None
838
- else self.config.output_attentions
839
- )
840
- output_router_logits = (
841
- output_router_logits
842
- if output_router_logits is not None
843
- else self.config.output_router_logits
844
- )
845
-
846
- output_hidden_states = (
847
- output_hidden_states
848
- if output_hidden_states is not None
849
- else self.config.output_hidden_states
850
- )
851
- return_dict = (
852
- return_dict if return_dict is not None else self.config.use_return_dict
853
- )
854
-
855
- # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
856
- outputs = self.model(
857
- input_ids=input_ids,
858
- attention_mask=attention_mask,
859
- position_ids=position_ids,
860
- past_key_values=past_key_values,
861
- inputs_embeds=inputs_embeds,
862
- use_cache=use_cache,
863
- output_attentions=output_attentions,
864
- output_hidden_states=output_hidden_states,
865
- output_router_logits=output_router_logits,
866
- return_dict=return_dict,
867
- )
868
-
869
- hidden_states = outputs[0]
870
- logits = self.lm_head(hidden_states)
871
- logits = logits * self.output_multiplier_scale
872
- logits = logits.float()
873
-
874
- loss = None
875
- if labels is not None:
876
- # Shift so that tokens < n predict n
877
- shift_logits = logits[..., :-1, :].contiguous()
878
- shift_labels = labels[..., 1:].contiguous()
879
- # Flatten the tokens
880
- loss_fct = nn.CrossEntropyLoss()
881
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
882
- shift_labels = shift_labels.view(-1)
883
- # Enable model parallelism
884
- shift_labels = shift_labels.to(shift_logits.device)
885
- loss = loss_fct(shift_logits, shift_labels)
886
-
887
- aux_loss = None
888
- if output_router_logits:
889
- aux_loss = load_balancing_loss_func(
890
- outputs.router_logits if return_dict else outputs[-1],
891
- self.num_experts,
892
- self.num_experts_per_tok,
893
- )
894
- if labels is not None:
895
- loss += self.router_aux_loss_coef * aux_loss
896
-
897
- if not return_dict:
898
- output = (logits,) + outputs[1:]
899
- if output_router_logits:
900
- output = (aux_loss,) + output
901
- return (loss,) + output if loss is not None else output
902
-
903
- return MoeCausalLMOutputWithPast(
904
- loss=loss,
905
- aux_loss=aux_loss,
906
- logits=logits,
907
- past_key_values=outputs.past_key_values,
908
- hidden_states=outputs.hidden_states,
909
- attentions=outputs.attentions,
910
- router_logits=outputs.router_logits,
911
- )
912
-
913
- def prepare_inputs_for_generation(
914
- self,
915
- input_ids,
916
- past_key_values=None,
917
- attention_mask=None,
918
- inputs_embeds=None,
919
- **kwargs,
920
- ):
921
- if past_key_values:
922
- input_ids = input_ids[:, -1:]
923
-
924
- position_ids = kwargs.get("position_ids", None)
925
- if attention_mask is not None and position_ids is None:
926
- # create position_ids on the fly for batch generation
927
- position_ids = attention_mask.long().cumsum(-1) - 1
928
- position_ids.masked_fill_(attention_mask == 0, 1)
929
- if past_key_values:
930
- position_ids = position_ids[:, -1].unsqueeze(-1)
931
-
932
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
933
- if inputs_embeds is not None and past_key_values is None:
934
- model_inputs = {"inputs_embeds": inputs_embeds}
935
- else:
936
- model_inputs = {"input_ids": input_ids}
937
-
938
- model_inputs.update(
939
- {
940
- "position_ids": position_ids,
941
- "past_key_values": past_key_values,
942
- "use_cache": kwargs.get("use_cache"),
943
- "attention_mask": attention_mask,
944
- }
945
- )
946
- return model_inputs
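
For context on what was removed: the deleted load_balancing_loss_func (copied from Mixtral) computes the Switch Transformer auxiliary load-balancing loss over the router logits. Below is a minimal, self-contained sketch of that computation for a single [num_tokens, num_experts] logit tensor; the helper name aux_load_balancing_loss and the toy sizes are illustrative only, and the tuple/None handling of the deleted function is omitted.

import torch
import torch.nn.functional as F

def aux_load_balancing_loss(gate_logits: torch.Tensor, num_experts: int, top_k: int = 2) -> torch.Tensor:
    # gate_logits: [num_tokens, num_experts] router logits (already concatenated across layers)
    routing_weights, selected_experts = torch.topk(gate_logits, top_k, dim=-1)
    routing_weights = routing_weights.softmax(dim=-1)

    # one_hot needs int64 indices; the extra singleton dim mirrors the deleted code
    selected_experts = selected_experts.to(torch.int64).unsqueeze(2)
    expert_mask = F.one_hot(selected_experts, num_experts)            # [T, top_k, 1, E]
    expert_mask = expert_mask.max(dim=-2).values.to(torch.float32)    # [T, top_k, E]

    tokens_per_expert = expert_mask.mean(dim=-2)                      # [T, E]
    router_prob_per_token = routing_weights.mean(dim=-1)              # [T]
    return torch.mean(tokens_per_expert * router_prob_per_token.unsqueeze(-1)) * num_experts**2

# Toy check: 8 experts, top-2 routing, 16 tokens of random logits.
print(aux_load_balancing_loss(torch.randn(16, 8), num_experts=8))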