# myt5-large-SD-prompts / train_stable_myt5-large.py
import os
import json
import gc
import logging
import torch
import pickle
from torch.utils.data import Dataset
from transformers import (
AutoModelForSeq2SeqLM,
AutoTokenizer,
TrainingArguments,
Trainer,
TrainerCallback,
DataCollatorForSeq2Seq,
)
# CONFIGURATION
MAX_ITEMS = None  # None = train on the full dataset
MAX_LENGTH = 256
PER_DEVICE_BATCH = 1
GRAD_ACC_STEPS = 16 # Increased due to higher MAX_LENGTH
LEARNING_RATE = 5e-5
NUM_TRAIN_EPOCHS = 1
WARMUP_STEPS = 200
FP16_TRAINING = False  # disabled as a workaround for fp16 issues on Windows
OPTIMIZER_CHOICE = "adamw_8bit"
MAX_GRAD_NORM_CLIP = 0.0  # 0.0 disables gradient clipping
GRADIENT_CHECKPOINTING = True
LOGGING_STEPS = 50
SAVE_STEPS = 1000
EVAL_STEPS = 500
SAVE_TOTAL_LIMIT = 20  # each checkpoint is ~7 GB
FIXED_PROMPT_FOR_GENERATION = "Create stable diffusion metadata based on the given english description. a futuristic city"
logging.basicConfig(level=logging.INFO, format="%(asctime)s — %(levelname)s — %(name)s — %(message)s")
log = logging.getLogger(__name__)
class SDPromptDataset(Dataset):
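    """Dataset of pre-tokenized instruction/output pairs; tokenized examples are cached to disk and reloaded on later runs."""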
def __init__(self, raw_data_list, tokenizer, max_length, dataset_type="train", cache_dir="cache"):
self.raw_data = raw_data_list
self.tokenizer = tokenizer
self.max_length = max_length
self.dataset_type = dataset_type
os.makedirs(cache_dir, exist_ok=True)
cache_file = os.path.join(cache_dir, f"{dataset_type}_{len(raw_data_list)}_{max_length}.pkl")
if os.path.exists(cache_file):
log.info(f"Loading cached {dataset_type} dataset from {cache_file}")
with open(cache_file, 'rb') as f:
self.examples = pickle.load(f)
log.info(f"Loaded {len(self.examples)} cached examples for {dataset_type}")
else:
log.info(f"Tokenizing {len(raw_data_list)} samples for {dataset_type} with {type(tokenizer).__name__}...")
self.examples = []
for i, item in enumerate(raw_data_list):
if i > 0 and (i % 5000 == 0 or i == len(raw_data_list) - 1):
log.info(f"Tokenized {i+1} / {len(raw_data_list)} samples for {dataset_type}")
instruction = item.get("instruction", "")
output = item.get("output", "")
input_encoding = tokenizer(
instruction, max_length=max_length, padding="max_length",
truncation=True, return_tensors="pt",
)
if self.dataset_type == "train" or (self.dataset_type == "eval" and output):
target_encoding = tokenizer(
output, max_length=max_length, padding="max_length",
truncation=True, return_tensors="pt",
)
labels = target_encoding["input_ids"].squeeze()
                    labels[labels == tokenizer.pad_token_id] = -100  # mask padding so it is ignored by the loss
else:
labels = None
example_data = {
"input_ids": input_encoding["input_ids"].squeeze(),
"attention_mask": input_encoding["attention_mask"].squeeze(),
}
if labels is not None:
example_data["labels"] = labels
self.examples.append(example_data)
log.info(f"Tokenization complete for {dataset_type}. Saving cache to {cache_file}")
with open(cache_file, 'wb') as f:
pickle.dump(self.examples, f)
log.info(f"Cache saved successfully")
def __len__(self):
return len(self.examples)
def __getitem__(self, idx):
return self.examples[idx]
def get_raw_example(self, idx):
return self.raw_data[idx]
def load_and_split_json_data(data_path, max_items_from_config=None):
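    """Load the JSON list, optionally keep only the first max_items_from_config items, and return a ~90/10 train/validation split."""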
log.info(f"Loading data from {data_path}...")
if not os.path.exists(data_path):
log.error(f"Data file not found: {data_path}")
raise FileNotFoundError(f"Data file not found: {data_path}")
with open(data_path, "r", encoding="utf-8") as f:
all_data = json.load(f)
log.info(f"Successfully loaded {len(all_data)} total items from JSON.")
if max_items_from_config is not None and max_items_from_config > 0:
num_to_take = min(max_items_from_config, len(all_data))
log.info(f"Keeping the first {num_to_take} samples as per MAX_ITEMS config.")
all_data = all_data[:num_to_take]
else:
log.info("Using the full dataset.")
if not all_data:
log.error("No data loaded or remaining.")
raise ValueError("No data to process.")
if len(all_data) < 20:
split_idx = max(1, int(0.5 * len(all_data)))
log.warning(f"Dataset very small ({len(all_data)} items). Adjusting split.")
else:
split_idx = int(0.9 * len(all_data))
split_idx = max(1, split_idx)
train_data = all_data[:split_idx]
val_data = all_data[split_idx:]
if not val_data and train_data:
val_data = [train_data[-1]]
log.warning("Validation set was empty after split, using one sample from training data for validation.")
if len(train_data) > 1:
train_data = train_data[:-1]
val_data = val_data[:min(len(val_data), 2000)] if val_data else None
if not train_data:
log.error("Training data empty.")
raise ValueError("Training data empty.")
log.info(f"Train samples: {len(train_data)}, Validation samples: {len(val_data) if val_data else 0}")
return train_data, val_data
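# Expected shape of the training JSON (illustrative sketch: the keys are what SDPromptDataset reads, the values are placeholders):
# [
#   {"instruction": "Create stable diffusion metadata based on the given english description. <description>",
#    "output": "<comma-separated stable diffusion tags>"},
#   ...
# ]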
def find_latest_checkpoint(output_dir):
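    """Return the highest-numbered checkpoint-* directory in output_dir that contains saved weights, or None."""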
if not os.path.isdir(output_dir):
return None
checkpoints = [d for d in os.listdir(output_dir) if d.startswith("checkpoint-") and os.path.isdir(os.path.join(output_dir, d))]
if not checkpoints:
return None
checkpoints.sort(key=lambda x: int(x.split('-')[-1]))
latest_checkpoint = os.path.join(output_dir, checkpoints[-1])
if os.path.exists(os.path.join(latest_checkpoint, "pytorch_model.bin")) or os.path.exists(os.path.join(latest_checkpoint, "model.safetensors")):
return latest_checkpoint
return None
def clear_cuda_cache():
log.info("Clearing CUDA cache...")
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def generate_and_log_fixed_sample(model, tokenizer, prompt_text, device, log_prefix="Sample"):
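    """Generate output for a fixed prompt with beam search and log the decoded text."""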
log.info(f"\n--- {log_prefix} Generation ---")
log.info(f"Input Prompt: {prompt_text}")
model.eval()
inputs = tokenizer(prompt_text, return_tensors="pt", max_length=MAX_LENGTH, truncation=True)
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
outputs = model.generate(
**inputs, max_length=MAX_LENGTH + 50,
num_beams=5, early_stopping=True, no_repeat_ngram_size=3,
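            # note: temperature/top_k/top_p only take effect when do_sample=True; pure beam search ignores them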
temperature=0.7, top_k=50, top_p=0.95
)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
log.info(f"Generated Output: {generated_text}")
log.info(f"--- End {log_prefix} Generation ---\n")
class ShowFixedEvalSampleCallback(TrainerCallback):
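    """TrainerCallback that logs a generation from a fixed prompt each time evaluation runs."""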
def __init__(self, tokenizer, prompt_text):
self.tokenizer = tokenizer
self.prompt_text = prompt_text
def on_evaluate(self, args, state, control, model=None, **kwargs):
if model is None:
return
device = next(model.parameters()).device
generate_and_log_fixed_sample(model, self.tokenizer, self.prompt_text, device, log_prefix="Evaluation Callback Sample")
model.train()
def Train(model_id: str, output_dir: str, data_path: str):
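    """Fine-tune model_id on the instruction/output pairs in data_path, resuming from the latest checkpoint in output_dir if present."""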
os.makedirs(output_dir, exist_ok=True)
clear_cuda_cache()
# Check for existing checkpoint to resume
resume_from_checkpoint = find_latest_checkpoint(output_dir)
if resume_from_checkpoint:
log.info(f"Found checkpoint to resume from: {resume_from_checkpoint}")
else:
log.info("No existing checkpoint found, starting fresh training")
log.info(f"Attempting to load MyT5Tokenizer for {model_id} (trust_remote_code=True).")
try:
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
log.info(f"Successfully loaded tokenizer: {type(tokenizer).__name__}")
except Exception as e:
log.error(f"Failed to load tokenizer for {model_id} (trust_remote_code=True): {e}")
return
train_raw_data, eval_raw_data = load_and_split_json_data(data_path, max_items_from_config=MAX_ITEMS)
if not train_raw_data:
return
train_dataset = SDPromptDataset(train_raw_data, tokenizer, MAX_LENGTH, dataset_type="train")
eval_dataset = SDPromptDataset(eval_raw_data, tokenizer, MAX_LENGTH, dataset_type="eval") if eval_raw_data else None
log.info(f"Loading model: {model_id}")
model = AutoModelForSeq2SeqLM.from_pretrained(
model_id,
torch_dtype=torch.float16 if FP16_TRAINING else torch.float32,
device_map="auto",
low_cpu_mem_usage=True,
)
if GRADIENT_CHECKPOINTING:
model.gradient_checkpointing_enable()
log.info("Grad-ckpt enabled.")
if OPTIMIZER_CHOICE == "adamw_8bit":
try:
import bitsandbytes
log.info(f"bitsandbytes version: {bitsandbytes.__version__} imported for adamw_8bit.")
except ImportError:
log.error("bitsandbytes not installed, required for optim='adamw_8bit'. Install: pip install bitsandbytes")
return
training_args = TrainingArguments(
output_dir=output_dir,
per_device_train_batch_size=PER_DEVICE_BATCH,
per_device_eval_batch_size=PER_DEVICE_BATCH * 2,
gradient_accumulation_steps=GRAD_ACC_STEPS,
learning_rate=LEARNING_RATE,
num_train_epochs=NUM_TRAIN_EPOCHS,
warmup_steps=WARMUP_STEPS,
logging_steps=LOGGING_STEPS,
save_strategy="steps",
save_steps=SAVE_STEPS,
eval_strategy="steps" if eval_dataset else "no",
eval_steps=EVAL_STEPS if eval_dataset else None,
save_total_limit=SAVE_TOTAL_LIMIT,
load_best_model_at_end=True if eval_dataset else False,
fp16=FP16_TRAINING,
optim=OPTIMIZER_CHOICE,
max_grad_norm=MAX_GRAD_NORM_CLIP,
gradient_checkpointing=GRADIENT_CHECKPOINTING,
group_by_length=True,
lr_scheduler_type="cosine",
weight_decay=0.01,
report_to="none",
)
fixed_sample_callback = ShowFixedEvalSampleCallback(tokenizer=tokenizer, prompt_text=FIXED_PROMPT_FOR_GENERATION)
callbacks_to_use = [fixed_sample_callback] if eval_dataset else []
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding="longest")
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=data_collator,
tokenizer=tokenizer,
callbacks=callbacks_to_use
)
log.info(f"Starting training with FP16_TRAINING={FP16_TRAINING}, optim='{OPTIMIZER_CHOICE}', LR={LEARNING_RATE}, GradClip={MAX_GRAD_NORM_CLIP}...")
try:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
except Exception as e:
log.exception(f"Unhandled error during trainer.train(): {e}")
return
log.info("Training completed.")
try:
final_model_path = os.path.join(output_dir, "final_model_after_train")
if not os.path.exists(final_model_path):
trainer.save_model(final_model_path)
log.info(f"Final model state explicitly saved to {final_model_path}")
else:
log.info(f"Best model was likely saved by load_best_model_at_end to a checkpoint within {output_dir}")
except Exception as e:
log.exception(f"Error saving final explicit model: {e}")
log.info("Train function finished.")
def Inference(base_model_id_for_tokenizer: str, trained_model_output_dir: str):
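    """Load the fine-tuned model (final save or latest checkpoint) from trained_model_output_dir and run a demo generation."""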
log.info(f"\n--- Starting Inference ---")
path_to_load_model_from = trained_model_output_dir
potential_final_model = os.path.join(trained_model_output_dir, "final_model_after_train")
if os.path.exists(potential_final_model) and (os.path.exists(os.path.join(potential_final_model, "pytorch_model.bin")) or os.path.exists(os.path.join(potential_final_model, "model.safetensors"))):
path_to_load_model_from = potential_final_model
log.info(f"Found 'final_model_after_train' at: {path_to_load_model_from}")
else:
latest_checkpoint = find_latest_checkpoint(trained_model_output_dir)
if latest_checkpoint:
path_to_load_model_from = latest_checkpoint
log.info(f"Found latest checkpoint: {path_to_load_model_from}")
elif not (os.path.exists(os.path.join(path_to_load_model_from, "pytorch_model.bin")) or os.path.exists(os.path.join(path_to_load_model_from, "model.safetensors"))):
log.error(f"No valid model found in {trained_model_output_dir} or its subdirectories. Cannot run inference.")
return
log.info(f"Attempting to load fine-tuned model from: {path_to_load_model_from}")
try:
model = AutoModelForSeq2SeqLM.from_pretrained(path_to_load_model_from, device_map="auto")
try:
tokenizer = AutoTokenizer.from_pretrained(path_to_load_model_from, trust_remote_code=True)
except Exception:
log.warning(f"Could not load tokenizer from {path_to_load_model_from}, trying base {base_model_id_for_tokenizer}")
tokenizer = AutoTokenizer.from_pretrained(base_model_id_for_tokenizer, trust_remote_code=True)
log.info(f"Successfully loaded model and tokenizer for inference. Model is on: {model.device}")
except Exception as e:
log.error(f"Failed to load model or tokenizer for inference: {e}")
return
device = next(model.parameters()).device
generate_and_log_fixed_sample(model, tokenizer, FIXED_PROMPT_FOR_GENERATION, device, log_prefix="Final Inference")
log.info(f"--- Inference Demo Finished ---")
def main():
Train('Tomlim/myt5-large', 'trained_model', 'DiscordPromptSD.json')
Inference('Tomlim/myt5-large', 'trained_model')
if __name__ == "__main__":
main()
# - SFW Cyberpunk City: `Nikon Z9 200mm f_8 ISO 160, (giant rifle structure), flawless ornate architecture, cyberpunk, neon lights, busy street, realistic, ray tracing, hasselblad`
# - SFW Fantasy Dragon Rider: `masterpiece, best quality, cinematic lighting, 1girl, solo, <lora:add_detail:0.55>`
# - NSFW Anime Succubus: `masterpiece, best quality, highly detailed background, intricate, 1girl, (full-face blush, aroused:1.3), long hair, medium breasts, nipples`