# FaseehGPT / modeling_arabic-gpt.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel
class ArabicGPTConfig(PretrainedConfig):
model_type = "arabic-gpt"
def __init__(self,
vocab_size=32000,
max_seq_len=1024,
embed_dim=768,
num_heads=12,
num_layers=12,
ff_dim=3072,
dropout=0.1,
**kwargs):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.max_seq_len = max_seq_len
self.embed_dim = embed_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.ff_dim = ff_dim
self.dropout = dropout
self.tie_word_embeddings = True
class ArabicGPTModel(PreTrainedModel):
config_class = ArabicGPTConfig
def __init__(self, config: ArabicGPTConfig):
super().__init__(config)
self.model = ArabicGPT(
vocab_size=config.vocab_size,
max_seq_len=config.max_seq_len,
embed_dim=config.embed_dim,
num_heads=config.num_heads,
num_layers=config.num_layers,
ff_dim=config.ff_dim,
dropout=config.dropout,
)
def forward(self, x):
return self.model(x)
    def generate(self, prompt_ids, max_new_tokens, temperature=1.0, top_k=50, top_p=0.9):
        # Forward the sampling parameters instead of re-hardcoding the defaults
        return self.model.generate(
            prompt_ids,
            max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
def get_input_embeddings(self):
return self.model.token_embedding
def set_input_embeddings(self, new_embeddings):
self.model.token_embedding = new_embeddings
def get_output_embeddings(self):
return self.model.lm_head
def tie_weights(self):
self.model.lm_head.weight = self.model.token_embedding.weight
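
# Example usage (a minimal sketch, not executed on import): instantiate the config and
# the wrapper directly. The token ids below are placeholders rather than the output of a
# real tokenizer, and the hyperparameters are just the ArabicGPTConfig defaults.
#
#   config = ArabicGPTConfig(vocab_size=32000)
#   model = ArabicGPTModel(config)
#   model.tie_weights()
#   prompt_ids = torch.tensor([[1, 15, 27, 42]])  # placeholder token ids
#   output_ids = model.generate(prompt_ids, max_new_tokens=20, temperature=0.8)
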
# Part 2: GPT Model Implementation
class AttentionHead(nn.Module):
def __init__(self, embed_dim, head_dim, mask=True):
super().__init__()
self.q = nn.Linear(embed_dim, head_dim)
self.k = nn.Linear(embed_dim, head_dim)
self.v = nn.Linear(embed_dim, head_dim)
self.mask = mask
self.scale = head_dim ** -0.5
def forward(self, x):
# x shape: (batch, seq_len, embed_dim)
batch_size, seq_len, _ = x.shape
# Linear projections
q = self.q(x) # (batch, seq_len, head_dim)
k = self.k(x) # (batch, seq_len, head_dim)
v = self.v(x) # (batch, seq_len, head_dim)
# Compute attention scores
attn = torch.bmm(q, k.transpose(1, 2)) * self.scale # (batch, seq_len, seq_len)
# Apply causal mask for decoder
if self.mask:
mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
attn.masked_fill_(mask, float('-inf'))
# Apply softmax and get weighted values
attn = F.softmax(attn, dim=-1)
output = torch.bmm(attn, v) # (batch, seq_len, head_dim)
return output
class MultiHeadAttention(nn.Module):
def __init__(self, embed_dim, num_heads, mask=True):
super().__init__()
self.heads = nn.ModuleList([
AttentionHead(embed_dim, embed_dim // num_heads, mask)
for _ in range(num_heads)
])
self.linear = nn.Linear(embed_dim, embed_dim)
def forward(self, x):
# Concatenate outputs from all heads
heads_output = torch.cat([head(x) for head in self.heads], dim=-1)
# Final linear projection
output = self.linear(heads_output)
return output
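
# Shape note (an added sketch based on the defaults above, not part of the original file):
# each head projects embed_dim -> embed_dim // num_heads, so with embed_dim=768 and
# num_heads=12 every head produces a 64-dim output and concatenating the 12 heads restores
# the 768-dim embedding before the final linear projection. embed_dim must therefore be
# divisible by num_heads.
#
#   mha = MultiHeadAttention(embed_dim=768, num_heads=12)
#   x = torch.randn(2, 16, 768)            # (batch, seq_len, embed_dim)
#   assert mha(x).shape == (2, 16, 768)
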
class FeedForward(nn.Module):
def __init__(self, embed_dim, ff_dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(embed_dim, ff_dim),
nn.GELU(),
nn.Linear(ff_dim, embed_dim)
)
def forward(self, x):
return self.net(x)
class TransformerBlock(nn.Module):
def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
super().__init__()
self.attn = MultiHeadAttention(embed_dim, num_heads)
self.ff = FeedForward(embed_dim, ff_dim)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# Self-attention with residual connection and layer norm
attn_output = self.attn(self.norm1(x))
x = x + self.dropout(attn_output)
# Feed-forward with residual connection and layer norm
ff_output = self.ff(self.norm2(x))
x = x + self.dropout(ff_output)
return x
class ArabicGPT(nn.Module):
def __init__(self, vocab_size, max_seq_len=1024, embed_dim=768, num_heads=12,
num_layers=12, ff_dim=3072, dropout=0.1):
super().__init__()
self.max_seq_len = max_seq_len
self.token_embedding = nn.Embedding(vocab_size, embed_dim)
self.position_embedding = nn.Embedding(max_seq_len, embed_dim)
# Transformer blocks
self.blocks = nn.ModuleList([
TransformerBlock(embed_dim, num_heads, ff_dim, dropout)
for _ in range(num_layers)
])
# Final layer norm
self.norm = nn.LayerNorm(embed_dim)
# Language model head
self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
        # Weight tying between the token embedding and the LM head is handled by
        # ArabicGPTModel.tie_weights() (config.tie_word_embeddings = True)
# Initialize weights
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
elif isinstance(module, nn.LayerNorm):
torch.nn.init.zeros_(module.bias)
torch.nn.init.ones_(module.weight)
def forward(self, x):
# x shape: (batch, seq_len)
batch_size, seq_len = x.shape
# Get positions
positions = torch.arange(0, seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)
# Get token and position embeddings
token_embed = self.token_embedding(x)
pos_embed = self.position_embedding(positions)
# Combine embeddings
x = token_embed + pos_embed
# Apply transformer blocks
for block in self.blocks:
x = block(x)
# Apply final layer norm
x = self.norm(x)
# Get logits
logits = self.lm_head(x)
return logits
def generate(self, prompt_ids, max_new_tokens, temperature=1.0, top_k=50, top_p=0.9):
"""Generate text using the model."""
self.eval()
with torch.no_grad():
# Convert prompt to tensor if needed
if not isinstance(prompt_ids, torch.Tensor):
prompt_ids = torch.tensor(prompt_ids, dtype=torch.long)
# Move to device and add batch dimension if needed
if len(prompt_ids.shape) == 1:
prompt_ids = prompt_ids.unsqueeze(0)
prompt_ids = prompt_ids.to(next(self.parameters()).device)
# Start with prompt
generated_ids = prompt_ids.clone()
# Generate new tokens
for _ in range(max_new_tokens):
# Take last context up to max sequence length
input_ids = generated_ids[:, -self.max_seq_len:]
# Get logits for next token
logits = self(input_ids)
next_token_logits = logits[:, -1, :]
# Apply temperature
if temperature > 0:
next_token_logits = next_token_logits / temperature
# Apply top-k filtering
if top_k > 0:
indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
next_token_logits[indices_to_remove] = float('-inf')
# Apply top-p (nucleus) filtering
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep the first token above threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
                    # Scatter the sorted mask back to the original vocabulary order
                    # (indexing with sorted_indices directly breaks for batch size > 1)
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    next_token_logits = next_token_logits.masked_fill(indices_to_remove, float('-inf'))
# Sample next token
probs = F.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
# Append next token to generated
generated_ids = torch.cat([generated_ids, next_token], dim=1)
                # Stop early at the end-of-sequence token
                # (assumes batch size 1 and that id 2 is the tokenizer's EOS id)
                if next_token.item() == 2:
                    break
return generated_ids
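

# Minimal smoke test (an added sketch, not part of the original model code). It builds a
# deliberately tiny configuration so that a forward pass and a short generation run on CPU
# in a few seconds; the vocabulary size and prompt ids below are arbitrary placeholders.
if __name__ == "__main__":
    config = ArabicGPTConfig(
        vocab_size=1000,
        max_seq_len=64,
        embed_dim=64,
        num_heads=4,
        num_layers=2,
        ff_dim=128,
    )
    model = ArabicGPTModel(config)
    model.tie_weights()

    # Forward pass: logits should have shape (batch, seq_len, vocab_size)
    dummy_ids = torch.randint(0, config.vocab_size, (2, 16))
    logits = model(dummy_ids)
    print("logits shape:", tuple(logits.shape))  # expected: (2, 16, 1000)

    # Short sampled continuation from a placeholder prompt
    prompt_ids = torch.tensor([[1, 5, 7, 11]])
    generated = model.generate(prompt_ids, max_new_tokens=10, temperature=0.8, top_k=20, top_p=0.9)
    print("generated ids:", generated[0].tolist())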