import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import regex as re
import collections
import os
import random
from tqdm import tqdm
from transformers import PreTrainedModel, PretrainedConfig

class ArabicGPTConfig(PretrainedConfig):
    model_type = "arabic-gpt"

    def __init__(self,
                 vocab_size=32000,
                 max_seq_len=1024,
                 embed_dim=768,
                 num_heads=12,
                 num_layers=12,
                 ff_dim=3072,
                 dropout=0.1,
                 **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.ff_dim = ff_dim
        self.dropout = dropout
        self.tie_word_embeddings = True



class ArabicGPTModel(PreTrainedModel):
    config_class = ArabicGPTConfig

    def __init__(self, config: ArabicGPTConfig):
        super().__init__(config)
        self.model = ArabicGPT(
            vocab_size=config.vocab_size,
            max_seq_len=config.max_seq_len,
            embed_dim=config.embed_dim,
            num_heads=config.num_heads,
            num_layers=config.num_layers,
            ff_dim=config.ff_dim,
            dropout=config.dropout,
        )

    def forward(self, x):
        return self.model(x)

    def generate(self, prompt_ids, max_new_tokens, temperature=1.0, top_k=50, top_p=0.9):
        # Forward the sampling parameters instead of re-hardcoding the defaults
        return self.model.generate(prompt_ids, max_new_tokens,
                                   temperature=temperature, top_k=top_k, top_p=top_p)

    def get_input_embeddings(self):
        return self.model.token_embedding

    def set_input_embeddings(self, new_embeddings):
        self.model.token_embedding = new_embeddings
    
    def get_output_embeddings(self):
        return self.model.lm_head
    
    def tie_weights(self):
        self.model.lm_head.weight = self.model.token_embedding.weight

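# Usage sketch (illustrative, not part of the original script): because ArabicGPTModel
# subclasses PreTrainedModel, the standard Hugging Face serialization helpers apply,
# e.g. (the directory name below is a placeholder):
#
#     config = ArabicGPTConfig(vocab_size=32000)
#     model = ArabicGPTModel(config)
#     model.save_pretrained("arabic-gpt-checkpoint")   # writes config.json + weights
#     reloaded = ArabicGPTModel.from_pretrained("arabic-gpt-checkpoint")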

# Part 2: GPT Model Implementation
class AttentionHead(nn.Module):
    def __init__(self, embed_dim, head_dim, mask=True):
        super().__init__()
        self.q = nn.Linear(embed_dim, head_dim)
        self.k = nn.Linear(embed_dim, head_dim)
        self.v = nn.Linear(embed_dim, head_dim)
        self.mask = mask
        self.scale = head_dim ** -0.5

    def forward(self, x):
        # x shape: (batch, seq_len, embed_dim)
        batch_size, seq_len, _ = x.shape

        # Linear projections
        q = self.q(x)  # (batch, seq_len, head_dim)
        k = self.k(x)  # (batch, seq_len, head_dim)
        v = self.v(x)  # (batch, seq_len, head_dim)

        # Compute attention scores
        attn = torch.bmm(q, k.transpose(1, 2)) * self.scale  # (batch, seq_len, seq_len)

        # Apply causal mask for decoder
        if self.mask:
            mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
            attn.masked_fill_(mask, float('-inf'))

        # Apply softmax and get weighted values
        attn = F.softmax(attn, dim=-1)
        output = torch.bmm(attn, v)  # (batch, seq_len, head_dim)

        return output

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, mask=True):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.heads = nn.ModuleList([
            AttentionHead(embed_dim, embed_dim // num_heads, mask)
            for _ in range(num_heads)
        ])
        self.linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        # Concatenate outputs from all heads
        heads_output = torch.cat([head(x) for head in self.heads], dim=-1)
        # Final linear projection
        output = self.linear(heads_output)
        return output

class FeedForward(nn.Module):
    def __init__(self, embed_dim, ff_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, embed_dim)
        )

    def forward(self, x):
        return self.net(x)

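# Pre-norm residual block (GPT-2 style): LayerNorm is applied to the input of each
# sub-layer and the residual is added to the un-normalized stream, which generally
# trains more stably than the post-norm arrangement of the original Transformer.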
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(embed_dim, num_heads)
        self.ff = FeedForward(embed_dim, ff_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Self-attention with residual connection and layer norm
        attn_output = self.attn(self.norm1(x))
        x = x + self.dropout(attn_output)

        # Feed-forward with residual connection and layer norm
        ff_output = self.ff(self.norm2(x))
        x = x + self.dropout(ff_output)

        return x

class ArabicGPT(nn.Module):
    def __init__(self, vocab_size, max_seq_len=1024, embed_dim=768, num_heads=12,
                 num_layers=12, ff_dim=3072, dropout=0.1):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(max_seq_len, embed_dim)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ])

        # Final layer norm
        self.norm = nn.LayerNorm(embed_dim)

        # Language model head
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)

        # Weight sharing between the token embedding and the LM head is handled by
        # ArabicGPTModel.tie_weights() (config.tie_word_embeddings=True), so the
        # direct assignment below is left disabled here.
        # self.lm_head.weight = self.token_embedding.weight

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(self, x):
        # x shape: (batch, seq_len)
        batch_size, seq_len = x.shape

        # Get positions
        positions = torch.arange(0, seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)

        # Get token and position embeddings
        token_embed = self.token_embedding(x)
        pos_embed = self.position_embedding(positions)

        # Combine embeddings
        x = token_embed + pos_embed

        # Apply transformer blocks
        for block in self.blocks:
            x = block(x)

        # Apply final layer norm
        x = self.norm(x)

        # Get logits
        logits = self.lm_head(x)

        return logits

    def generate(self, prompt_ids, max_new_tokens, temperature=1.0, top_k=50, top_p=0.9):
        """Generate text using the model."""
        self.eval()
        with torch.no_grad():
            # Convert prompt to tensor if needed
            if not isinstance(prompt_ids, torch.Tensor):
                prompt_ids = torch.tensor(prompt_ids, dtype=torch.long)

            # Move to device and add batch dimension if needed
            if len(prompt_ids.shape) == 1:
                prompt_ids = prompt_ids.unsqueeze(0)
            prompt_ids = prompt_ids.to(next(self.parameters()).device)

            # Start with prompt
            generated_ids = prompt_ids.clone()

            # Generate new tokens
            for _ in range(max_new_tokens):
                # Take last context up to max sequence length
                input_ids = generated_ids[:, -self.max_seq_len:]

                # Get logits for next token
                logits = self(input_ids)
                next_token_logits = logits[:, -1, :]

                # Apply temperature
                if temperature > 0:
                    next_token_logits = next_token_logits / temperature

                # Apply top-k filtering
                if top_k > 0:
                    top_k = min(top_k, next_token_logits.size(-1))  # guard against top_k > vocab size
                    indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                    next_token_logits[indices_to_remove] = float('-inf')

                # Apply top-p (nucleus) filtering
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                    # Remove tokens with cumulative probability above the threshold
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift the indices to the right to keep the first token above threshold
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0

                    # Scatter the removal mask back from sorted order to vocabulary order
                    # (correct for batched inputs, unlike indexing with the flattened mask)
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    next_token_logits = next_token_logits.masked_fill(indices_to_remove, float('-inf'))

                # Sample next token
                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

                # Append next token to generated
                generated_ids = torch.cat([generated_ids, next_token], dim=1)

                # Stop if the EOS token is produced (id 2 is assumed here; the
                # actual id depends on the tokenizer)
                if next_token.item() == 2:
                    break

            return generated_ids
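

if __name__ == "__main__":
    # Minimal smoke test (added for illustration, not part of the original file):
    # builds a deliberately small model and exercises the forward and generate
    # paths with random token ids, since no tokenizer is defined in this module.
    config = ArabicGPTConfig(vocab_size=1000, max_seq_len=128, embed_dim=64,
                             num_heads=4, num_layers=2, ff_dim=256)
    model = ArabicGPTModel(config)
    model.tie_weights()

    dummy_ids = torch.randint(0, config.vocab_size, (2, 16))  # (batch, seq_len)
    logits = model(dummy_ids)
    print("logits shape:", tuple(logits.shape))  # expected: (2, 16, 1000)

    generated = model.generate(dummy_ids[:1], max_new_tokens=8,
                               temperature=0.8, top_k=20, top_p=0.9)
    print("generated shape:", tuple(generated.shape))  # (1, <= 24); may stop early on EOS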