|
|
|
|
|
|
|
import torch |
|
from datasets import load_dataset |
|
from transformers import ( |
|
AutoTokenizer, |
|
AutoModelForCausalLM, |
|
Trainer, |
|
TrainingArguments, |
|
DataCollatorForLanguageModeling |
|
) |
|
|
|
|
|
print("📥 Loading dataset...")
# Small slice (1%) of the de-duplicated CodeParrot Python corpus keeps this
# demo run short; enlarge the split for a real training run.
dataset = load_dataset("codeparrot/codeparrot-clean", split="train[:1%]")

model_name = "distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# GPT-2 tokenizers ship without a pad token; reuse EOS so padded batches work.
tokenizer.pad_token = tokenizer.eos_token
|
|
|
|
|
def tokenize_fn(examples):
    """Tokenize a batch of code samples into fixed-length model inputs.

    Uses the module-level ``tokenizer``; every sample in the batch is padded
    or truncated to exactly 128 tokens so batches have a uniform shape.
    """
    encode_kwargs = {
        "truncation": True,
        "padding": "max_length",
        "max_length": 128,
    }
    return tokenizer(examples["content"], **encode_kwargs)
|
|
|
print("🔤 Tokenizing dataset...")
# Batched map for speed; drop the raw "content" column so only token fields remain.
tokenized_dataset = dataset.map(tokenize_fn, batched=True, remove_columns=["content"])

# mlm=False selects causal-LM collation: labels are derived from input_ids
# (the model performs the next-token shift internally).
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

print("⚙️ Loading model...")
model = AutoModelForCausalLM.from_pretrained(model_name)
|
|
|
|
|
# NOTE(review): `evaluation_strategy` was renamed to `eval_strategy` in newer
# transformers releases — confirm the installed version accepts this keyword.
training_args = TrainingArguments(
    output_dir="./mini_gpt_code",
    overwrite_output_dir=True,
    evaluation_strategy="no",          # no eval split is provided
    per_device_train_batch_size=2,     # tiny batch: fits CPU / small GPUs
    num_train_epochs=1,
    save_strategy="epoch",
    logging_dir="./logs",
    save_safetensors=True,             # store checkpoints in safetensors format
    fp16=torch.cuda.is_available(),    # mixed precision only when a GPU exists
    push_to_hub=False
)
|
|
|
|
|
# NOTE(review): passing `tokenizer=` to Trainer is deprecated in newer
# transformers (renamed to `processing_class`) but still functional — verify
# against the pinned library version.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator
)
|
|
|
|
|
print("🚀 Training started...")
trainer.train()

# Persist model weights (safetensors) and tokenizer files into one directory
# so it can later be reloaded with from_pretrained(save_path).
save_path = "./mini_gpt_code_safetensors"
trainer.save_model(save_path)
tokenizer.save_pretrained(save_path)
print(f"✅ Training complete. Model saved at {save_path}")
|
|
|
|
|
print("💻 Generating Python code...")
prompt = "Write a Python function to calculate factorial:\n"
# Move the encoded prompt onto the model's device. After training with
# fp16/CUDA the Trainer leaves the model on the GPU, and the original
# CPU-resident tensors would raise a device-mismatch error inside generate().
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

model.eval()
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_length=100,        # total length cap, prompt tokens included
        temperature=0.7,
        do_sample=True,
        top_p=0.9,
        # GPT-2 has no native pad token; passing EOS explicitly silences the
        # "Setting pad_token_id to eos_token_id" warning and makes padding
        # behavior deterministic.
        pad_token_id=tokenizer.eos_token_id,
    )

print("\nGenerated Code:\n")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
|