import os
import json
import gc
import logging
import pickle

import torch
from torch.utils.data import Dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    TrainerCallback,
    DataCollatorForSeq2Seq,
)
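
# Overview: this script fine-tunes a seq2seq model (the MyT5 checkpoint passed in
# from main()) to turn plain-English image descriptions into Stable Diffusion
# prompt metadata. Pipeline: load instruction/output pairs from a JSON file,
# tokenize and cache them, train with the Hugging Face Trainer (resuming from the
# latest checkpoint if one exists), then run a short generation demo.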

MAX_ITEMS = None
MAX_LENGTH = 256
PER_DEVICE_BATCH = 1
GRAD_ACC_STEPS = 16
LEARNING_RATE = 5e-5
NUM_TRAIN_EPOCHS = 1
WARMUP_STEPS = 200
FP16_TRAINING = False
OPTIMIZER_CHOICE = "adamw_8bit"
MAX_GRAD_NORM_CLIP = 0.0
GRADIENT_CHECKPOINTING = True
LOGGING_STEPS = 50
SAVE_STEPS = 1000
EVAL_STEPS = 500
SAVE_TOTAL_LIMIT = 20
FIXED_PROMPT_FOR_GENERATION = "Create stable diffusion metadata based on the given english description. a futuristic city"
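
# Notes on the settings above:
# - MAX_ITEMS = None keeps the whole dataset; set an integer to truncate it.
# - The effective batch size is PER_DEVICE_BATCH * GRAD_ACC_STEPS = 16 per device.
# - OPTIMIZER_CHOICE = "adamw_8bit" requires the bitsandbytes package.
# - MAX_GRAD_NORM_CLIP = 0.0 disables gradient clipping in the Trainer.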

logging.basicConfig(level=logging.INFO, format="%(asctime)s — %(levelname)s — %(name)s — %(message)s")
log = logging.getLogger(__name__)


class SDPromptDataset(Dataset):
    def __init__(self, raw_data_list, tokenizer, max_length, dataset_type="train", cache_dir="cache"):
        self.raw_data = raw_data_list
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.dataset_type = dataset_type

        os.makedirs(cache_dir, exist_ok=True)
        cache_file = os.path.join(cache_dir, f"{dataset_type}_{len(raw_data_list)}_{max_length}.pkl")

        if os.path.exists(cache_file):
            log.info(f"Loading cached {dataset_type} dataset from {cache_file}")
            with open(cache_file, "rb") as f:
                self.examples = pickle.load(f)
            log.info(f"Loaded {len(self.examples)} cached examples for {dataset_type}")
        else:
            log.info(f"Tokenizing {len(raw_data_list)} samples for {dataset_type} with {type(tokenizer).__name__}...")
            self.examples = []

            for i, item in enumerate(raw_data_list):
                if i > 0 and (i % 5000 == 0 or i == len(raw_data_list) - 1):
                    log.info(f"Tokenized {i + 1} / {len(raw_data_list)} samples for {dataset_type}")

                instruction = item.get("instruction", "")
                output = item.get("output", "")

                input_encoding = tokenizer(
                    instruction, max_length=max_length, padding="max_length",
                    truncation=True, return_tensors="pt",
                )

                if self.dataset_type == "train" or (self.dataset_type == "eval" and output):
                    target_encoding = tokenizer(
                        output, max_length=max_length, padding="max_length",
                        truncation=True, return_tensors="pt",
                    )
                    labels = target_encoding["input_ids"].squeeze()
                    labels[labels == tokenizer.pad_token_id] = -100
                else:
                    labels = None

                example_data = {
                    "input_ids": input_encoding["input_ids"].squeeze(),
                    "attention_mask": input_encoding["attention_mask"].squeeze(),
                }
                if labels is not None:
                    example_data["labels"] = labels

                self.examples.append(example_data)

            log.info(f"Tokenization complete for {dataset_type}. Saving cache to {cache_file}")
            with open(cache_file, "wb") as f:
                pickle.dump(self.examples, f)
            log.info("Cache saved successfully")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]

    def get_raw_example(self, idx):
        return self.raw_data[idx]
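
# The training JSON is expected to be a list of records with "instruction" and
# "output" keys, for example (illustrative values only):
#   {"instruction": "Create stable diffusion metadata based on the given english description. a castle at dusk",
#    "output": "a castle at dusk, highly detailed, cinematic lighting, 8k"}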


def load_and_split_json_data(data_path, max_items_from_config=None):
    log.info(f"Loading data from {data_path}...")
    if not os.path.exists(data_path):
        log.error(f"Data file not found: {data_path}")
        raise FileNotFoundError(f"Data file not found: {data_path}")

    with open(data_path, "r", encoding="utf-8") as f:
        all_data = json.load(f)
    log.info(f"Successfully loaded {len(all_data)} total items from JSON.")

    if max_items_from_config is not None and max_items_from_config > 0:
        num_to_take = min(max_items_from_config, len(all_data))
        log.info(f"Keeping the first {num_to_take} samples as per MAX_ITEMS config.")
        all_data = all_data[:num_to_take]
    else:
        log.info("Using the full dataset.")

    if not all_data:
        log.error("No data loaded or remaining.")
        raise ValueError("No data to process.")

    if len(all_data) < 20:
        split_idx = max(1, int(0.5 * len(all_data)))
        log.warning(f"Dataset very small ({len(all_data)} items). Adjusting split.")
    else:
        split_idx = int(0.9 * len(all_data))
    split_idx = max(1, split_idx)

    train_data = all_data[:split_idx]
    val_data = all_data[split_idx:]

    if not val_data and train_data:
        val_data = [train_data[-1]]
        log.warning("Validation set was empty after split, using one sample from training data for validation.")
        if len(train_data) > 1:
            train_data = train_data[:-1]

    val_data = val_data[:min(len(val_data), 2000)] if val_data else None

    if not train_data:
        log.error("Training data empty.")
        raise ValueError("Training data empty.")

    log.info(f"Train samples: {len(train_data)}, Validation samples: {len(val_data) if val_data else 0}")
    return train_data, val_data
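
# Split behaviour of load_and_split_json_data: 90/10 train/validation for
# normal-sized datasets, 50/50 when there are fewer than 20 items, validation
# capped at 2000 samples, and one training sample is reused for validation if
# the split would otherwise leave it empty.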


def find_latest_checkpoint(output_dir):
    if not os.path.isdir(output_dir):
        return None

    checkpoints = [d for d in os.listdir(output_dir) if d.startswith("checkpoint-") and os.path.isdir(os.path.join(output_dir, d))]
    if not checkpoints:
        return None

    checkpoints.sort(key=lambda x: int(x.split("-")[-1]))
    latest_checkpoint = os.path.join(output_dir, checkpoints[-1])

    if os.path.exists(os.path.join(latest_checkpoint, "pytorch_model.bin")) or os.path.exists(os.path.join(latest_checkpoint, "model.safetensors")):
        return latest_checkpoint

    return None


def clear_cuda_cache():
    log.info("Clearing CUDA cache...")
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def generate_and_log_fixed_sample(model, tokenizer, prompt_text, device, log_prefix="Sample"):
    log.info(f"\n--- {log_prefix} Generation ---")
    log.info(f"Input Prompt: {prompt_text}")
    model.eval()
    inputs = tokenizer(prompt_text, return_tensors="pt", max_length=MAX_LENGTH, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        # Beam search decoding; note that temperature/top_k/top_p only take
        # effect when do_sample=True, otherwise generate() ignores them.
        outputs = model.generate(
            **inputs, max_length=MAX_LENGTH + 50,
            num_beams=5, early_stopping=True, no_repeat_ngram_size=3,
            temperature=0.7, top_k=50, top_p=0.95,
        )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    log.info(f"Generated Output: {generated_text}")
    log.info(f"--- End {log_prefix} Generation ---\n")


class ShowFixedEvalSampleCallback(TrainerCallback):
    def __init__(self, tokenizer, prompt_text):
        self.tokenizer = tokenizer
        self.prompt_text = prompt_text

    def on_evaluate(self, args, state, control, model=None, **kwargs):
        if model is None:
            return
        device = next(model.parameters()).device
        generate_and_log_fixed_sample(model, self.tokenizer, self.prompt_text, device, log_prefix="Evaluation Callback Sample")
        model.train()
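
# The callback above runs after each evaluation pass (every EVAL_STEPS steps),
# logs a generation for the fixed prompt, then puts the model back into training
# mode because generate_and_log_fixed_sample switches it to eval().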


def Train(model_id: str, output_dir: str, data_path: str):
    os.makedirs(output_dir, exist_ok=True)
    clear_cuda_cache()

    resume_from_checkpoint = find_latest_checkpoint(output_dir)
    if resume_from_checkpoint:
        log.info(f"Found checkpoint to resume from: {resume_from_checkpoint}")
    else:
        log.info("No existing checkpoint found, starting fresh training")

    log.info(f"Attempting to load MyT5Tokenizer for {model_id} (trust_remote_code=True).")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        log.info(f"Successfully loaded tokenizer: {type(tokenizer).__name__}")
    except Exception as e:
        log.error(f"Failed to load tokenizer for {model_id} (trust_remote_code=True): {e}")
        return

    train_raw_data, eval_raw_data = load_and_split_json_data(data_path, max_items_from_config=MAX_ITEMS)
    if not train_raw_data:
        return

    train_dataset = SDPromptDataset(train_raw_data, tokenizer, MAX_LENGTH, dataset_type="train")
    eval_dataset = SDPromptDataset(eval_raw_data, tokenizer, MAX_LENGTH, dataset_type="eval") if eval_raw_data else None

    log.info(f"Loading model: {model_id}")
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if FP16_TRAINING else torch.float32,
        device_map="auto",
        low_cpu_mem_usage=True,
    )
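
    # device_map="auto" and low_cpu_mem_usage both require the accelerate package.
    # Caveat: with FP16_TRAINING enabled the weights are loaded directly in
    # float16 while the Trainer also runs fp16 mixed precision, which can trip
    # the gradient scaler ("Attempting to unscale FP16 gradients") on some
    # transformers/torch versions.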

    if GRADIENT_CHECKPOINTING:
        model.gradient_checkpointing_enable()
        log.info("Grad-ckpt enabled.")

    if OPTIMIZER_CHOICE == "adamw_8bit":
        try:
            import bitsandbytes
            log.info(f"bitsandbytes version: {bitsandbytes.__version__} imported for adamw_8bit.")
        except ImportError:
            log.error("bitsandbytes not installed, required for optim='adamw_8bit'. Install: pip install bitsandbytes")
            return

    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=PER_DEVICE_BATCH,
        per_device_eval_batch_size=PER_DEVICE_BATCH * 2,
        gradient_accumulation_steps=GRAD_ACC_STEPS,
        learning_rate=LEARNING_RATE,
        num_train_epochs=NUM_TRAIN_EPOCHS,
        warmup_steps=WARMUP_STEPS,
        logging_steps=LOGGING_STEPS,
        save_strategy="steps",
        save_steps=SAVE_STEPS,
        eval_strategy="steps" if eval_dataset else "no",
        eval_steps=EVAL_STEPS if eval_dataset else None,
        save_total_limit=SAVE_TOTAL_LIMIT,
        load_best_model_at_end=True if eval_dataset else False,
        fp16=FP16_TRAINING,
        optim=OPTIMIZER_CHOICE,
        max_grad_norm=MAX_GRAD_NORM_CLIP,
        gradient_checkpointing=GRADIENT_CHECKPOINTING,
        group_by_length=True,
        lr_scheduler_type="cosine",
        weight_decay=0.01,
        report_to="none",
    )
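
    # TrainingArguments caveats (recent transformers releases):
    # - eval_strategy is the newer spelling of evaluation_strategy; older
    #   versions only accept the long form.
    # - With load_best_model_at_end=True and step-based strategies, save_steps
    #   must be a round multiple of eval_steps (1000 / 500 here); the best model
    #   defaults to the lowest eval_loss.
    # - group_by_length has little effect here because every example was already
    #   padded to max_length during tokenization.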

    fixed_sample_callback = ShowFixedEvalSampleCallback(tokenizer=tokenizer, prompt_text=FIXED_PROMPT_FOR_GENERATION)
    callbacks_to_use = [fixed_sample_callback] if eval_dataset else []

    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding="longest")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
        callbacks=callbacks_to_use,
    )
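
    # Passing tokenizer= also makes the Trainer save the tokenizer with each
    # checkpoint; recent transformers versions deprecate this argument in favour
    # of processing_class, but both are still accepted at the time of writing.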
log.info(f"Starting training with FP16_TRAINING={FP16_TRAINING}, optim='{OPTIMIZER_CHOICE}', LR={LEARNING_RATE}, GradClip={MAX_GRAD_NORM_CLIP}...")
|
|
try:
|
|
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
|
except Exception as e:
|
|
log.exception(f"Unhandled error during trainer.train(): {e}")
|
|
return
|
|
|
|
log.info("Training completed.")
|
|
try:
|
|
final_model_path = os.path.join(output_dir, "final_model_after_train")
|
|
if not os.path.exists(final_model_path):
|
|
trainer.save_model(final_model_path)
|
|
log.info(f"Final model state explicitly saved to {final_model_path}")
|
|
else:
|
|
log.info(f"Best model was likely saved by load_best_model_at_end to a checkpoint within {output_dir}")
|
|
except Exception as e:
|
|
log.exception(f"Error saving final explicit model: {e}")
|
|
log.info("Train function finished.")
|
|
|
|


def Inference(base_model_id_for_tokenizer: str, trained_model_output_dir: str):
    log.info("\n--- Starting Inference ---")

    path_to_load_model_from = trained_model_output_dir
    potential_final_model = os.path.join(trained_model_output_dir, "final_model_after_train")

    if os.path.exists(potential_final_model) and (os.path.exists(os.path.join(potential_final_model, "pytorch_model.bin")) or os.path.exists(os.path.join(potential_final_model, "model.safetensors"))):
        path_to_load_model_from = potential_final_model
        log.info(f"Found 'final_model_after_train' at: {path_to_load_model_from}")
    else:
        latest_checkpoint = find_latest_checkpoint(trained_model_output_dir)
        if latest_checkpoint:
            path_to_load_model_from = latest_checkpoint
            log.info(f"Found latest checkpoint: {path_to_load_model_from}")
        elif not (os.path.exists(os.path.join(path_to_load_model_from, "pytorch_model.bin")) or os.path.exists(os.path.join(path_to_load_model_from, "model.safetensors"))):
            log.error(f"No valid model found in {trained_model_output_dir} or its subdirectories. Cannot run inference.")
            return

    log.info(f"Attempting to load fine-tuned model from: {path_to_load_model_from}")

    try:
        model = AutoModelForSeq2SeqLM.from_pretrained(path_to_load_model_from, device_map="auto")
        try:
            tokenizer = AutoTokenizer.from_pretrained(path_to_load_model_from, trust_remote_code=True)
        except Exception:
            log.warning(f"Could not load tokenizer from {path_to_load_model_from}, trying base {base_model_id_for_tokenizer}")
            tokenizer = AutoTokenizer.from_pretrained(base_model_id_for_tokenizer, trust_remote_code=True)
        log.info(f"Successfully loaded model and tokenizer for inference. Model is on: {model.device}")
    except Exception as e:
        log.error(f"Failed to load model or tokenizer for inference: {e}")
        return

    device = next(model.parameters()).device
    generate_and_log_fixed_sample(model, tokenizer, FIXED_PROMPT_FOR_GENERATION, device, log_prefix="Final Inference")
    log.info("--- Inference Demo Finished ---")


def main():
    Train('Tomlim/myt5-large', 'trained_model', 'DiscordPromptSD.json')
    Inference('Tomlim/myt5-large', 'trained_model')


if __name__ == "__main__":
    main()