File size: 4,689 Bytes
05d640e
 
8ef2cad
05d640e
fc5dc52
05d640e
fc5dc52
05d640e
 
 
 
 
 
 
 
 
 
8ef2cad
05d640e
8ef2cad
05d640e
 
8ef2cad
 
05d640e
 
 
 
8ef2cad
fc5dc52
8ef2cad
 
 
 
 
 
 
 
 
 
 
 
 
 
05d640e
 
8ef2cad
 
 
 
 
 
 
05d640e
 
8ef2cad
 
05d640e
 
8ef2cad
 
 
 
 
 
 
 
 
 
05d640e
 
 
8ef2cad
05d640e
 
8ef2cad
 
05d640e
fc5dc52
05d640e
8ef2cad
05d640e
8ef2cad
05d640e
 
 
 
 
8ef2cad
05d640e
 
 
fc5dc52
 
 
 
 
 
 
 
 
8ef2cad
 
fc5dc52
 
 
 
 
 
 
 
 
 
 
 
 
 
05d640e
 
 
 
 
 
fc5dc52
05d640e
 
fc5dc52
 
05d640e
 
 
 
fc5dc52
 
05d640e
 
 
 
 
 
 
fc5dc52
 
05d640e
 
fc5dc52
 
 
05d640e
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import torch
import torch.nn as nn

from torch.nn import functional as F
from bitblas.cache import OperatorCache

from .layers import layer_norm, mlp, Linear
from .rope import apply_rotary_emb, precompute_freqs_cis
from .config import TextConfig


def text_encoder(input_ids: torch.Tensor, w: nn.Module):
    """Map token ids to their embedding vectors.

    Args:
        input_ids: Integer tensor of token ids.
        w: Container exposing the embedding table as ``w.wte``.

    Returns:
        The rows of ``w.wte`` selected by ``input_ids``; the result has the
        shape of ``input_ids`` with the embedding dimension appended.
    """
    embedding_table = w.wte
    token_embeddings = F.embedding(input_ids, embedding_table)
    return token_embeddings


def attn(
    x: torch.Tensor,
    w: nn.Module,
    freqs_cis: torch.Tensor,
    kv_cache: nn.Module,
    attn_mask: torch.Tensor,
    n_heads: int,
    n_kv_heads: int,
    position_ids: torch.Tensor,
):
    """Run one attention layer over ``x``.

    Args:
        x: 3-D input unpacked as ``(bsz, q_len, d_model)``.
        w: Container exposing the fused projection ``w.qkv`` and the output
            projection ``w.proj``.
        freqs_cis: Rotary table forwarded unchanged to ``apply_rotary_emb``.
        kv_cache: Object with ``update(position_ids, k, v)``, or ``None`` to
            attend over the freshly computed keys/values only.
        attn_mask: Mask forwarded to ``F.scaled_dot_product_attention``.
        n_heads: Number of query heads.
        n_kv_heads: Number of key/value heads.
        position_ids: Positions forwarded to the rotary embedding and to the
            cache update.

    Returns:
        Tensor of shape ``(bsz, q_len, d_model)`` produced by ``w.proj``'s
        input reshape; the final shape is whatever ``w.proj`` returns.
    """
    bsz, q_len, d_model = x.shape
    head_dim = d_model // n_heads

    # Fused projection: the last axis is laid out as [ Q | K | V ], i.e.
    # (n_heads + 2 * n_kv_heads) * head_dim features.
    fused = w.qkv(x)
    k_start = n_heads * head_dim
    v_start = k_start + n_kv_heads * head_dim

    def to_heads(chunk, heads):
        # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
        return chunk.view(bsz, q_len, heads, head_dim).transpose(1, 2)

    q = to_heads(fused[..., :k_start], n_heads)
    k = to_heads(fused[..., k_start:v_start], n_kv_heads)
    v = to_heads(fused[..., v_start:], n_kv_heads)

    # Rotary embedding is applied to queries and keys only, never to values.
    q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
    k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads)

    if kv_cache is not None:
        # From here on, attend over whatever keys/values the cache hands back.
        k, v = kv_cache.update(position_ids, k, v)

    # Grouped-query mode is needed only when the head counts differ.
    use_gqa = n_heads != n_kv_heads
    attended = F.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask, enable_gqa=use_gqa
    )

    # (bsz, n_heads, q_len, head_dim) -> (bsz, q_len, d_model)
    merged = attended.transpose(1, 2).reshape(bsz, q_len, d_model)
    return w.proj(merged)


def text_decoder(
    x: torch.Tensor,
    w: nn.Module,
    attn_mask: torch.Tensor,
    position_ids: torch.Tensor,
    config: TextConfig,
):
    """Pass ``x`` through every block in ``w.blocks``.

    Each block normalises its input once and feeds that same normalised
    tensor to both the attention and the MLP branch; both branch outputs are
    then added to the un-normalised residual stream.

    Args:
        x: Hidden states entering the first block.
        w: Container exposing ``blocks`` (each with ``ln``, ``attn``,
            ``kv_cache`` and ``mlp``) and ``freqs_cis``.
        attn_mask: Mask forwarded to every attention call.
        position_ids: Positions forwarded to every attention call.
        config: Supplies ``n_heads`` and ``n_kv_heads``.

    Returns:
        Hidden states after the last block (``x`` itself if there are no
        blocks).
    """
    hidden = x
    for layer in w.blocks:
        normed = layer_norm(hidden, layer.ln)

        attn_out = attn(
            normed,
            layer.attn,
            freqs_cis=w.freqs_cis,
            kv_cache=layer.kv_cache,
            attn_mask=attn_mask,
            n_heads=config.n_heads,
            n_kv_heads=config.n_kv_heads,
            position_ids=position_ids,
        )
        # The MLP reads the normalised block input, not the attention output.
        mlp_out = mlp(normed, layer.mlp)

        hidden = hidden + attn_out + mlp_out

    return hidden


def lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
    """Compute logits from the final position of ``hidden_BTC``.

    Args:
        hidden_BTC: 3-D hidden states; only index ``-1`` along axis 1 is used.
        w: Container exposing ``post_ln`` and the ``lm_head`` projection.

    Returns:
        The output of ``w.lm_head`` applied to the layer-normalised last
        position.
    """
    # Keep just the last position: (B, T, C) -> (B, C).
    last_step = hidden_BTC[:, -1, :]
    return w.lm_head(layer_norm(last_step, w.post_ln))


def build_text_model(
    config: TextConfig,
    linear_dtype: torch.dtype = torch.float16,
    layernorm_dtype: torch.dtype = torch.float16,
) -> nn.Module:
    """Assemble the module tree of the text model.

    Args:
        config: Supplies ``dim``, ``ff_dim``, ``n_layers``, ``n_heads``,
            ``n_kv_heads``, ``vocab_size`` and ``max_context``;
            ``group_size`` is read only when ``linear_dtype`` is
            ``torch.int8``.
        linear_dtype: dtype handed to every project ``Linear`` layer
            (attention and MLP projections).
        layernorm_dtype: dtype of the layer norms, and also of ``lm_head``
            and the ``wte`` embedding table.

    Returns:
        An ``nn.ModuleDict`` with ``blocks`` (one entry per layer, each
        holding ``ln``, ``attn.qkv``, ``attn.proj``, ``mlp.fc1`` and
        ``mlp.fc2``), ``post_ln`` and ``lm_head``, plus an uninitialised
        ``wte`` parameter and a non-persistent ``freqs_cis`` buffer. No
        ``kv_cache`` entry is created here.
    """
    # NOTE: layernorm_dtype is used for the layer norms, lm_head and wte,
    # not just for the layer norms.
    print(
        "Initializing quantized backend. This only has to run once, but may take a few minutes."
    )

    # Width of the fused QKV projection: dim * (n_heads + 2 * n_kv_heads) / n_heads.
    # Evaluated in exact integer arithmetic; the float form
    # int(dim * (1 + 2 * n_kv_heads / n_heads)) can land just below the true
    # value and truncate one short (e.g. dim=240, n_heads=3, n_kv_heads=1
    # gives 399 instead of 400).
    qkv_dim = (config.dim * (config.n_heads + 2 * config.n_kv_heads)) // config.n_heads

    # A group size is only passed on for int8 linears; otherwise it stays None.
    group_size = config.group_size if linear_dtype == torch.int8 else None

    def create_linear(in_features, out_features, dtype=linear_dtype):
        # Factory for Linear layers so the shared dtype / group_size do not
        # have to be repeated at every call site.
        return Linear(
            in_features=in_features,
            out_features=out_features,
            dtype=dtype,
            group_size=group_size,
        )

    def create_block():
        # One transformer block: a single layer norm feeding attention and MLP.
        return nn.ModuleDict(
            {
                "ln": nn.LayerNorm(config.dim, dtype=layernorm_dtype),
                "attn": nn.ModuleDict(
                    {
                        "qkv": create_linear(config.dim, qkv_dim),
                        "proj": create_linear(config.dim, config.dim),
                    }
                ),
                "mlp": nn.ModuleDict(
                    {
                        "fc1": create_linear(config.dim, config.ff_dim),
                        "fc2": create_linear(config.ff_dim, config.dim),
                    }
                ),
            }
        )

    text = nn.ModuleDict(
        {
            "blocks": nn.ModuleList(
                [create_block() for _ in range(config.n_layers)]
            ),
            "post_ln": nn.LayerNorm(config.dim, dtype=layernorm_dtype),
            "lm_head": nn.Linear(config.dim, config.vocab_size, dtype=layernorm_dtype),
        }
    )

    # torch.empty leaves the embedding table uninitialised; real values must
    # be written into it before the model is used.
    text.wte = nn.Parameter(
        torch.empty(config.vocab_size, config.dim, dtype=layernorm_dtype)
    )

    # persistent=False keeps the rotary table out of the state_dict.
    text.register_buffer(
        "freqs_cis",
        precompute_freqs_cis(config.dim // (2 * config.n_heads), config.max_context),
        persistent=False,
    )

    return text