import os
import json
import gc
import logging
import pickle

import torch
from torch.utils.data import Dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    TrainerCallback,
    DataCollatorForSeq2Seq,
)

# CONFIGURATION
MAX_ITEMS = None              # optional cap on dataset size; None uses the full dataset
MAX_LENGTH = 256
PER_DEVICE_BATCH = 1
GRAD_ACC_STEPS = 16           # increased due to higher MAX_LENGTH; effective batch = 1 * 16 = 16
LEARNING_RATE = 5e-5
NUM_TRAIN_EPOCHS = 1
WARMUP_STEPS = 200
FP16_TRAINING = False         # disabled to work around fp16 issues on Windows
OPTIMIZER_CHOICE = "adamw_8bit"
MAX_GRAD_NORM_CLIP = 0.0      # 0.0 disables gradient clipping in the HF Trainer
GRADIENT_CHECKPOINTING = True
LOGGING_STEPS = 50
SAVE_STEPS = 1000
EVAL_STEPS = 500
SAVE_TOTAL_LIMIT = 20         # each checkpoint is ~7 GB
FIXED_PROMPT_FOR_GENERATION = (
    "Create stable diffusion metadata based on the given english description. "
    "a futuristic city"
)

logging.basicConfig(level=logging.INFO, format="%(asctime)s — %(levelname)s — %(name)s — %(message)s")
log = logging.getLogger(__name__)


class SDPromptDataset(Dataset):
    def __init__(self, raw_data_list, tokenizer, max_length, dataset_type="train", cache_dir="cache"):
        self.raw_data = raw_data_list
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.dataset_type = dataset_type

        os.makedirs(cache_dir, exist_ok=True)
        cache_file = os.path.join(cache_dir, f"{dataset_type}_{len(raw_data_list)}_{max_length}.pkl")

        if os.path.exists(cache_file):
            log.info(f"Loading cached {dataset_type} dataset from {cache_file}")
            with open(cache_file, 'rb') as f:
                self.examples = pickle.load(f)
            log.info(f"Loaded {len(self.examples)} cached examples for {dataset_type}")
        else:
            log.info(f"Tokenizing {len(raw_data_list)} samples for {dataset_type} with {type(tokenizer).__name__}...")
            self.examples = []
            for i, item in enumerate(raw_data_list):
                if i > 0 and (i % 5000 == 0 or i == len(raw_data_list) - 1):
                    log.info(f"Tokenized {i + 1} / {len(raw_data_list)} samples for {dataset_type}")

                instruction = item.get("instruction", "")
                output = item.get("output", "")

                input_encoding = tokenizer(
                    instruction,
                    max_length=max_length,
                    padding="max_length",
                    truncation=True,
                    return_tensors="pt",
                )

                if self.dataset_type == "train" or (self.dataset_type == "eval" and output):
                    target_encoding = tokenizer(
                        output,
                        max_length=max_length,
                        padding="max_length",
                        truncation=True,
                        return_tensors="pt",
                    )
                    labels = target_encoding["input_ids"].squeeze()
                    # Mask padding so it is ignored by the cross-entropy loss.
                    labels[labels == tokenizer.pad_token_id] = -100
                else:
                    labels = None

                example_data = {
                    "input_ids": input_encoding["input_ids"].squeeze(),
                    "attention_mask": input_encoding["attention_mask"].squeeze(),
                }
                if labels is not None:
                    example_data["labels"] = labels
                self.examples.append(example_data)

            log.info(f"Tokenization complete for {dataset_type}. Saving cache to {cache_file}")
            with open(cache_file, 'wb') as f:
                pickle.dump(self.examples, f)
            log.info("Cache saved successfully")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]

    def get_raw_example(self, idx):
        return self.raw_data[idx]
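
# NOTE: SDPromptDataset pads every example to MAX_LENGTH, so the "longest"
# padding in DataCollatorForSeq2Seq and the group_by_length option below have
# no practical effect. A minimal sketch of the dynamic-padding alternative
# (an assumption, not what this script currently does): tokenize without
# padding and let the collator pad each batch to its longest member, which
# typically saves memory at MAX_LENGTH=256.
#
#     input_encoding = tokenizer(instruction, max_length=max_length, truncation=True)
#     target_encoding = tokenizer(output, max_length=max_length, truncation=True)
#     example_data = {
#         "input_ids": input_encoding["input_ids"],
#         "attention_mask": input_encoding["attention_mask"],
#         "labels": target_encoding["input_ids"],
#     }
#     # DataCollatorForSeq2Seq then pads inputs with pad_token_id and labels
#     # with its default label_pad_token_id=-100, so the manual -100 masking
#     # above becomes unnecessary.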

def load_and_split_json_data(data_path, max_items_from_config=None):
    """Load a JSON list of {"instruction": ..., "output": ...} records and split it roughly 90/10."""
    log.info(f"Loading data from {data_path}...")
    if not os.path.exists(data_path):
        log.error(f"Data file not found: {data_path}")
        raise FileNotFoundError(f"Data file not found: {data_path}")

    with open(data_path, "r", encoding="utf-8") as f:
        all_data = json.load(f)
    log.info(f"Successfully loaded {len(all_data)} total items from JSON.")

    if max_items_from_config is not None and max_items_from_config > 0:
        num_to_take = min(max_items_from_config, len(all_data))
        log.info(f"Keeping the first {num_to_take} samples as per MAX_ITEMS config.")
        all_data = all_data[:num_to_take]
    else:
        log.info("Using the full dataset.")

    if not all_data:
        log.error("No data loaded or remaining.")
        raise ValueError("No data to process.")

    if len(all_data) < 20:
        split_idx = max(1, int(0.5 * len(all_data)))
        log.warning(f"Dataset very small ({len(all_data)} items). Adjusting split.")
    else:
        split_idx = int(0.9 * len(all_data))
    split_idx = max(1, split_idx)

    train_data = all_data[:split_idx]
    val_data = all_data[split_idx:]

    if not val_data and train_data:
        val_data = [train_data[-1]]
        log.warning("Validation set was empty after split, using one sample from training data for validation.")
        if len(train_data) > 1:
            train_data = train_data[:-1]

    val_data = val_data[:min(len(val_data), 2000)] if val_data else None

    if not train_data:
        log.error("Training data empty.")
        raise ValueError("Training data empty.")

    log.info(f"Train samples: {len(train_data)}, Validation samples: {len(val_data) if val_data else 0}")
    return train_data, val_data
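
# A minimal, hypothetical record illustrating the schema load_and_split_json_data
# expects in DiscordPromptSD.json (field names taken from SDPromptDataset; the
# example values are assumptions, useful only for a quick smoke test):
#
#     [
#         {
#             "instruction": "Create stable diffusion metadata based on the given english description. a futuristic city",
#             "output": "cyberpunk, neon lights, busy street, realistic, ray tracing"
#         }
#     ]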

def find_latest_checkpoint(output_dir):
    if not os.path.isdir(output_dir):
        return None
    checkpoints = [d for d in os.listdir(output_dir)
                   if d.startswith("checkpoint-") and os.path.isdir(os.path.join(output_dir, d))]
    if not checkpoints:
        return None
    checkpoints.sort(key=lambda x: int(x.split('-')[-1]))
    latest_checkpoint = os.path.join(output_dir, checkpoints[-1])
    if (os.path.exists(os.path.join(latest_checkpoint, "pytorch_model.bin"))
            or os.path.exists(os.path.join(latest_checkpoint, "model.safetensors"))):
        return latest_checkpoint
    return None


def clear_cuda_cache():
    log.info("Clearing CUDA cache...")
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def generate_and_log_fixed_sample(model, tokenizer, prompt_text, device, log_prefix="Sample"):
    log.info(f"\n--- {log_prefix} Generation ---")
    log.info(f"Input Prompt: {prompt_text}")
    model.eval()
    inputs = tokenizer(prompt_text, return_tensors="pt", max_length=MAX_LENGTH, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        # Beam search ignores sampling parameters (temperature, top_k, top_p)
        # unless do_sample=True, so only beam-search options are passed here.
        outputs = model.generate(
            **inputs,
            max_length=MAX_LENGTH + 50,
            num_beams=5,
            early_stopping=True,
            no_repeat_ngram_size=3,
        )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    log.info(f"Generated Output: {generated_text}")
    log.info(f"--- End {log_prefix} Generation ---\n")
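
def generate_with_sampling(model, tokenizer, prompt_text, device):
    # Hypothetical helper (not called anywhere): a sampling-based alternative to
    # the beam search above, using the temperature/top_k/top_p values the
    # original call passed; they only take effect once do_sample=True is set.
    model.eval()
    inputs = tokenizer(prompt_text, return_tensors="pt", max_length=MAX_LENGTH, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=MAX_LENGTH + 50,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            no_repeat_ngram_size=3,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)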

class ShowFixedEvalSampleCallback(TrainerCallback):
    def __init__(self, tokenizer, prompt_text):
        self.tokenizer = tokenizer
        self.prompt_text = prompt_text

    def on_evaluate(self, args, state, control, model=None, **kwargs):
        if model is None:
            return
        device = next(model.parameters()).device
        generate_and_log_fixed_sample(model, self.tokenizer, self.prompt_text, device,
                                      log_prefix="Evaluation Callback Sample")
        model.train()


def Train(model_id: str, output_dir: str, data_path: str):
    os.makedirs(output_dir, exist_ok=True)
    clear_cuda_cache()

    # Check for an existing checkpoint to resume from.
    resume_from_checkpoint = find_latest_checkpoint(output_dir)
    if resume_from_checkpoint:
        log.info(f"Found checkpoint to resume from: {resume_from_checkpoint}")
    else:
        log.info("No existing checkpoint found, starting fresh training")

    log.info(f"Attempting to load MyT5Tokenizer for {model_id} (trust_remote_code=True).")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        log.info(f"Successfully loaded tokenizer: {type(tokenizer).__name__}")
    except Exception as e:
        log.error(f"Failed to load tokenizer for {model_id} (trust_remote_code=True): {e}")
        return

    train_raw_data, eval_raw_data = load_and_split_json_data(data_path, max_items_from_config=MAX_ITEMS)
    if not train_raw_data:
        return

    train_dataset = SDPromptDataset(train_raw_data, tokenizer, MAX_LENGTH, dataset_type="train")
    eval_dataset = SDPromptDataset(eval_raw_data, tokenizer, MAX_LENGTH, dataset_type="eval") if eval_raw_data else None

    log.info(f"Loading model: {model_id}")
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if FP16_TRAINING else torch.float32,
        device_map="auto",
        low_cpu_mem_usage=True,
    )

    if GRADIENT_CHECKPOINTING:
        model.gradient_checkpointing_enable()
        log.info("Gradient checkpointing enabled.")

    if OPTIMIZER_CHOICE == "adamw_8bit":
        try:
            import bitsandbytes
            log.info(f"bitsandbytes version: {bitsandbytes.__version__} imported for adamw_8bit.")
        except ImportError:
            log.error("bitsandbytes not installed, required for optim='adamw_8bit'. Install: pip install bitsandbytes")
            return

    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=PER_DEVICE_BATCH,
        per_device_eval_batch_size=PER_DEVICE_BATCH * 2,
        gradient_accumulation_steps=GRAD_ACC_STEPS,
        learning_rate=LEARNING_RATE,
        num_train_epochs=NUM_TRAIN_EPOCHS,
        warmup_steps=WARMUP_STEPS,
        logging_steps=LOGGING_STEPS,
        save_strategy="steps",
        save_steps=SAVE_STEPS,  # must stay a multiple of eval_steps for load_best_model_at_end
        eval_strategy="steps" if eval_dataset else "no",
        eval_steps=EVAL_STEPS if eval_dataset else None,
        save_total_limit=SAVE_TOTAL_LIMIT,
        load_best_model_at_end=True if eval_dataset else False,
        fp16=FP16_TRAINING,
        optim=OPTIMIZER_CHOICE,
        max_grad_norm=MAX_GRAD_NORM_CLIP,
        gradient_checkpointing=GRADIENT_CHECKPOINTING,
        group_by_length=True,
        lr_scheduler_type="cosine",
        weight_decay=0.01,
        report_to="none",
    )

    fixed_sample_callback = ShowFixedEvalSampleCallback(tokenizer=tokenizer, prompt_text=FIXED_PROMPT_FOR_GENERATION)
    callbacks_to_use = [fixed_sample_callback] if eval_dataset else []

    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding="longest")

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
        callbacks=callbacks_to_use,
    )

    log.info(f"Starting training with FP16_TRAINING={FP16_TRAINING}, optim='{OPTIMIZER_CHOICE}', "
             f"LR={LEARNING_RATE}, GradClip={MAX_GRAD_NORM_CLIP}...")
    try:
        trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    except Exception as e:
        log.exception(f"Unhandled error during trainer.train(): {e}")
        return

    log.info("Training completed.")
    try:
        final_model_path = os.path.join(output_dir, "final_model_after_train")
        if not os.path.exists(final_model_path):
            trainer.save_model(final_model_path)
            log.info(f"Final model state explicitly saved to {final_model_path}")
        else:
            log.info(f"Best model was likely saved by load_best_model_at_end to a checkpoint within {output_dir}")
    except Exception as e:
        log.exception(f"Error saving final explicit model: {e}")

    log.info("Train function finished.")


def Inference(base_model_id_for_tokenizer: str, trained_model_output_dir: str):
    log.info("\n--- Starting Inference ---")
    path_to_load_model_from = trained_model_output_dir

    potential_final_model = os.path.join(trained_model_output_dir, "final_model_after_train")
    if os.path.exists(potential_final_model) and (
            os.path.exists(os.path.join(potential_final_model, "pytorch_model.bin"))
            or os.path.exists(os.path.join(potential_final_model, "model.safetensors"))):
        path_to_load_model_from = potential_final_model
        log.info(f"Found 'final_model_after_train' at: {path_to_load_model_from}")
    else:
        latest_checkpoint = find_latest_checkpoint(trained_model_output_dir)
        if latest_checkpoint:
            path_to_load_model_from = latest_checkpoint
            log.info(f"Found latest checkpoint: {path_to_load_model_from}")
        elif not (os.path.exists(os.path.join(path_to_load_model_from, "pytorch_model.bin"))
                  or os.path.exists(os.path.join(path_to_load_model_from, "model.safetensors"))):
            log.error(f"No valid model found in {trained_model_output_dir} or its subdirectories. Cannot run inference.")
            return

    log.info(f"Attempting to load fine-tuned model from: {path_to_load_model_from}")
    try:
        model = AutoModelForSeq2SeqLM.from_pretrained(path_to_load_model_from, device_map="auto")
        try:
            tokenizer = AutoTokenizer.from_pretrained(path_to_load_model_from, trust_remote_code=True)
        except Exception:
            log.warning(f"Could not load tokenizer from {path_to_load_model_from}, trying base {base_model_id_for_tokenizer}")
            tokenizer = AutoTokenizer.from_pretrained(base_model_id_for_tokenizer, trust_remote_code=True)
        log.info(f"Successfully loaded model and tokenizer for inference. Model is on: {model.device}")
    except Exception as e:
        log.error(f"Failed to load model or tokenizer for inference: {e}")
        return

    device = next(model.parameters()).device
    generate_and_log_fixed_sample(model, tokenizer, FIXED_PROMPT_FOR_GENERATION, device, log_prefix="Final Inference")
    log.info("--- Inference Demo Finished ---")


def main():
    Train('Tomlim/myt5-large', 'trained_model', 'DiscordPromptSD.json')
    Inference('Tomlim/myt5-large', 'trained_model')


if __name__ == "__main__":
    main()

# Example prompts:
# - SFW Cyberpunk City: `Nikon Z9 200mm f_8 ISO 160, (giant rifle structure), flawless ornate architecture, cyberpunk, neon lights, busy street, realistic, ray tracing, hasselblad`
# - SFW Fantasy Dragon Rider: `masterpiece, best quality, cinematic lighting, 1girl, solo, `
# - NSFW Anime Succubus: `masterpiece, best quality, highly detailed background, intricate, 1girl, (full-face blush, aroused:1.3), long hair, medium breasts, nipples`
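
# A minimal, hypothetical sketch (not used above) of running several prompts
# through the fine-tuned model in one batch; tokenizer padding and batch_decode
# are standard transformers APIs, everything else mirrors
# generate_and_log_fixed_sample:
#
#     def batch_generate(model, tokenizer, prompts, device):
#         inputs = tokenizer(prompts, return_tensors="pt", padding=True,
#                            max_length=MAX_LENGTH, truncation=True)
#         inputs = {k: v.to(device) for k, v in inputs.items()}
#         with torch.no_grad():
#             outputs = model.generate(**inputs, max_length=MAX_LENGTH + 50,
#                                      num_beams=5, early_stopping=True)
#         return tokenizer.batch_decode(outputs, skip_special_tokens=True)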