"""
Pretrain Veronica-Polymorphic from scratch (clean mixture: FinePDFs / DCLM / FineWeb-Edu).

Basic example:
    python veronica-polymorphic/scripts/train_veronica.py \
        --config veronica-polymorphic/configs/veronica-pretrain-12L.json \
        --dataset_paths data/mix_optimal_50_30_20_2048 \
        --output_dir veronica-polymorphic/runs/veronica-pretrain-vMix-2048 \
        --per_device_train_batch_size 4 \
        --gradient_accumulation_steps 4 \
        --learning_rate 2e-4 \
        --label_smoothing 0.01 \
        --rep_alpha 0.0 \
        --max_steps 60000 \
        --max_seq_len 2048

You can use datasets packed at different lengths (e.g., 512 / 1024 / 2048) in separate runs to build a length curriculum.
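
For example, an illustrative first stage at 512 tokens (the dataset path and
output directory below are assumptions, not files created by this script):

    python veronica-polymorphic/scripts/train_veronica.py \
        --config veronica-polymorphic/configs/veronica-pretrain-12L.json \
        --dataset_paths data/mix_optimal_50_30_20_512 \
        --output_dir veronica-polymorphic/runs/veronica-pretrain-vMix-512 \
        --max_seq_len 512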
| """ |
|
|
| import os |
| import re |
| import glob |
| import json |
| import math |
| import argparse |
| import random |
| from dataclasses import dataclass |
| from typing import Dict, List, Optional |
|
|
| import torch |
| import torch.nn.functional as F |
| from datasets import load_from_disk |
| from transformers import ( |
| AutoTokenizer, |
| Trainer, |
| TrainingArguments, |
| TrainerCallback, |
| CONFIG_MAPPING, |
| MODEL_FOR_CAUSAL_LM_MAPPING, |
| LogitsProcessorList, |
| NoRepeatNGramLogitsProcessor, |
| RepetitionPenaltyLogitsProcessor, |
| ) |
|
|
| |
| from veronica.configuration_veronica import VeronicaConfig |
| from veronica.modeling_veronica import VeronicaForCausalLM |
| from veronica.modeling_components import Fp32LayerNorm |
|
|
| CONFIG_MAPPING.register("veronica", VeronicaConfig) |
| MODEL_FOR_CAUSAL_LM_MAPPING.register(VeronicaConfig, VeronicaForCausalLM) |
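
# Registering the custom architecture lets the transformers Auto* classes resolve
# model_type "veronica" in any process where these registrations have run (e.g.,
# after importing this module). A reload sketch; the checkpoint path is illustrative:
#
#   from transformers import AutoModelForCausalLM
#   model = AutoModelForCausalLM.from_pretrained("veronica-polymorphic/runs/veronica-pretrain-vMix-2048")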

# Disable CUDA-graph capture for torch.compile / inductor and silence tokenizer fork parallelism.
os.environ.setdefault("TORCH_COMPILE_USE_CUDAGRAPHS", "0")
os.environ.setdefault("TORCHINDUCTOR_DISABLE_CUDAGRAPHS", "1")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")


def find_latest_checkpoint(run_dir: str) -> Optional[str]:
    ckpts = glob.glob(os.path.join(run_dir, "checkpoint-*"))
    if not ckpts:
        return None
    ckpts.sort(key=lambda p: int(re.search(r"checkpoint-(\d+)", p).group(1)))
    return ckpts[-1]


def build_tokenizer(candidates: List[str], save_dir: str) -> AutoTokenizer:
    """
    Try to load an existing tokenizer from the provided paths;
    otherwise fall back to gpt2 and add basic special tokens.
    """
    tok = None
    for p in candidates:
        if os.path.exists(p):
            try:
                tok = AutoTokenizer.from_pretrained(p, use_fast=True)
                print(f"[tokenizer] loaded from {p}")
                break
            except Exception:
                pass
    if tok is None:
        print("[tokenizer] fallback: gpt2")
        tok = AutoTokenizer.from_pretrained("gpt2", use_fast=True)

    specials: Dict[str, str] = {}
    if tok.eos_token is None:
        specials["eos_token"] = "<|eos|>"
    if tok.pad_token is None:
        specials["pad_token"] = "<|pad|>"
    if tok.bos_token is None:
        specials["bos_token"] = "<|bos|>"

    if specials:
        tok.add_special_tokens(specials)

    tok.save_pretrained(save_dir)
    tok = AutoTokenizer.from_pretrained(save_dir, use_fast=True)
    base_vocab = tok.vocab_size
    effective_vocab = len(tok)
    print(
        f"[tokenizer] base_vocab={base_vocab} added={effective_vocab - base_vocab} "
        f"effective_vocab={effective_vocab} eos={tok.eos_token_id} "
        f"pad={tok.pad_token_id} bos={tok.bos_token_id}"
    )
    return tok


def load_cfg_with_vocab(cfg_path: str, tok: AutoTokenizer) -> VeronicaConfig:
    """
    Load the config and adapt it to the tokenizer vocabulary.
    The model is designed as UN-TIED (lm_head != wte).
    """
    with open(cfg_path, "r", encoding="utf-8") as f:
        d = json.load(f)
    cfg = VeronicaConfig(**d)
    cfg.model_type = "veronica"
    cfg.vocab_size = int(len(tok))
    return cfg


def init_model_from_config(cfg: VeronicaConfig, tok: AutoTokenizer) -> VeronicaForCausalLM:
    model = VeronicaForCausalLM(cfg)
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    dtype = torch.bfloat16 if use_bf16 else (torch.float16 if torch.cuda.is_available() else torch.float32)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(dtype=dtype, device=device)

    effective_vocab = len(tok)
    emb = model.get_input_embeddings().weight
    head = model.lm_head.weight

    # Resize embeddings / lm_head if the tokenizer added tokens; new rows are
    # initialized to the mean of the existing ones.
    if emb.shape[0] != effective_vocab or head.shape[0] != effective_vocab:
        old_vocab = emb.shape[0]
        print(f"[model] resize_token_embeddings: {old_vocab} -> {effective_vocab}")
        model.resize_token_embeddings(effective_vocab)
        with torch.no_grad():
            new_emb = model.get_input_embeddings().weight
            new_head = model.lm_head.weight
            mean_emb = new_emb[:old_vocab].mean(dim=0, keepdim=True)
            mean_head = new_head[:old_vocab].mean(dim=0, keepdim=True)
            if effective_vocab > old_vocab:
                new_emb[old_vocab:] = mean_emb
                new_head[old_vocab:] = mean_head

    # Keep the wrapped LayerNorms in float32 regardless of the model dtype.
    for m in model.modules():
        if isinstance(m, Fp32LayerNorm):
            m.ln.to(dtype=torch.float32)

    model.config.use_cache = False
    n_params = sum(p.numel() for p in model.parameters())
    print(f"[model] params={n_params:,} vocab={effective_vocab}")
    return model


def load_mix_dataset(path: str):
    """
    Load a packed dataset (train/validation) from disk.
    Expected HuggingFace formats: a DatasetDict with 'train' and 'validation',
    or a single Dataset that gets split 99/1.
    """
    ds = load_from_disk(path)
    if isinstance(ds, dict) and "train" in ds and "validation" in ds:
        return ds["train"], ds["validation"]
    split = ds.train_test_split(test_size=0.01, seed=42)
    return split["train"], split["test"]


@dataclass
class CausalCollator:
    tokenizer: AutoTokenizer
    mask_runs: bool = False
    run_len: int = 4
    max_seq_len: Optional[int] = None

    def _mask_degenerate_runs(self, labels: torch.Tensor):
        """
        Mask degenerate runs (e.g., '____', '....') with length >= run_len.
        Mostly legacy; can be left off with a clean dataset.
        """
        try:
            id_us = self.tokenizer.encode("_", add_special_tokens=False)[0]
            id_dot = self.tokenizer.encode(".", add_special_tokens=False)[0]
        except Exception:
            return
        B, T = labels.size()
        for b in range(B):
            cnt_u = cnt_d = 0
            for t in range(T):
                tok = int(labels[b, t].item())
                if tok == id_us:
                    cnt_u += 1
                    cnt_d = 0
                elif tok == id_dot:
                    cnt_d += 1
                    cnt_u = 0
                else:
                    cnt_u = cnt_d = 0
                if cnt_u >= self.run_len or cnt_d >= self.run_len:
                    labels[b, t] = -100
    def _crop(self, ids: torch.Tensor) -> torch.Tensor:
        """
        If max_seq_len is set and the sequence is longer,
        crop a random window of length max_seq_len.
        """
        if self.max_seq_len is None:
            return ids
        L = ids.size(0)
        if L <= self.max_seq_len:
            return ids
        start = random.randint(0, L - self.max_seq_len)
        end = start + self.max_seq_len
        return ids[start:end]
    def __call__(self, features):
        ids_list = []
        for f in features:
            ids = torch.tensor(f["input_ids"], dtype=torch.long)
            ids = self._crop(ids)
            ids_list.append(ids)

        # Explicit None check: a pad_token_id of 0 is falsy but valid.
        pad_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
        ids = torch.nn.utils.rnn.pad_sequence(ids_list, batch_first=True, padding_value=pad_id)
        attn = torch.where(ids == pad_id, 0, 1)

        labels = ids.clone()
        labels[labels == pad_id] = -100
        if self.mask_runs:
            self._mask_degenerate_runs(labels)

        return {"input_ids": ids, "attention_mask": attn, "labels": labels}


SMOKE_PROMPTS = [
    "The world we live in today is",
    "Understanding complex ideas requires",
    "Human intelligence differs from artificial intelligence because",
    "A good system design is based on",
    "In the middle of every difficulty lies",
    "Once upon a time, there was a scientist who",
]


class RouterAndSmokeCallback(TrainerCallback):
    def __init__(self, tok: AutoTokenizer):
        self.tok = tok

    def on_log(self, args, state, control, **kwargs):
        model = kwargs.get("model", None)
        if model is None:
            return
        try:
            if hasattr(model, "router_alpha_mean") and model.router_alpha_mean is not None:
                alpha = model.router_alpha_mean.detach().float().cpu()
                p = alpha / alpha.sum()
                # Normalized entropy in [0, 1]; 1.0 means perfectly uniform routing.
                ent = -(p * (p.clamp_min(1e-9)).log()).sum()
                ent_norm = float(ent / math.log(len(p)))
                print(f"[router] alpha={alpha.tolist()} entropy_norm={ent_norm:.4f}")
        except Exception:
            pass
    def on_evaluate(self, args, state, control, **kwargs):
        model = kwargs.get("model", None)
        if model is None:
            return
        model.eval()
        dev = next(model.parameters()).device

        prompt = random.choice(SMOKE_PROMPTS)
        ids = self.tok(prompt, return_tensors="pt").to(dev)

        processors = LogitsProcessorList([
            NoRepeatNGramLogitsProcessor(3),
            RepetitionPenaltyLogitsProcessor(1.1),
        ])

        with torch.no_grad():
            out = model.generate(
                **ids,
                max_new_tokens=64,
                do_sample=False,
                logits_processor=processors,
                eos_token_id=self.tok.eos_token_id,
                pad_token_id=(self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id),
                use_cache=True,
            )
        txt = self.tok.decode(out[0], skip_special_tokens=True)
        completion = txt[len(prompt):].strip() if txt.startswith(prompt) else txt
        print(f"\n[SMOKE] {prompt} → {completion}\n")
        model.train()


class RouterScheduleCallback(TrainerCallback):
    """
    Linearly schedule router_tau and router_aux_weight between start and end of training.
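
    Illustrative numbers (matching the CLI defaults): with tau_start=1.6,
    tau_end=1.1, total_steps=60000 and tau_freeze_steps=4000, tau stays at 1.6
    until step 4000 and then decays linearly to 1.1 at step 60000, while
    router_aux_weight ramps 0.005 -> 0.012 over the full 0..60000 range.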
| """ |
|
|
| def __init__( |
| self, |
| tau_start: float, |
| tau_end: float, |
| aux_start: float, |
| aux_end: float, |
| total_steps: int, |
| tau_freeze_steps: int = 0, |
| force_prob: float = 0.0, |
| force_warmup_steps: int = 0, |
| ): |
| self.tau_start = float(tau_start) |
| self.tau_end = float(tau_end) |
| self.aux_start = float(aux_start) |
| self.aux_end = float(aux_end) |
| self.total_steps = max(int(total_steps), 1) |
| self.tau_freeze_steps = max(int(tau_freeze_steps), 0) |
| self.force_prob = float(force_prob) |
| self.force_warmup_steps = max(int(force_warmup_steps), 0) |
|
|
| def _interp(self, start: float, end: float, step: int, span: int) -> float: |
| t = min(max(step, 0), span) |
| alpha = t / float(max(span, 1)) |
| return (1.0 - alpha) * start + alpha * end |
|

    def on_step_begin(self, args, state, control, **kwargs):
        model = kwargs.get("model", None)
        if model is None:
            return
        step = state.global_step

        # Hold tau constant during the freeze window, then interpolate over the remaining steps.
        if step < self.tau_freeze_steps:
            new_tau = self.tau_start
        else:
            rem_step = step - self.tau_freeze_steps
            rem_span = max(self.total_steps - self.tau_freeze_steps, 1)
            new_tau = self._interp(self.tau_start, self.tau_end, rem_step, rem_span)

        # The auxiliary weight ramps over the whole run.
        new_aux = self._interp(self.aux_start, self.aux_end, step, self.total_steps)

        if hasattr(model, "config"):
            model.config.router_tau = new_tau
            model.config.router_aux_weight = new_aux

        # Push tau to each block's router and clear any previously forced branch.
        for block in getattr(model, "blocks", []):
            if hasattr(block, "mlp"):
                block.mlp.router_tau = new_tau
                block.mlp.force_func = -1

        # During warmup, occasionally force a random branch so every function receives gradient.
        if step < self.force_warmup_steps and self.force_prob > 0.0:
            if random.random() < self.force_prob:
                for block in getattr(model, "blocks", []):
                    if hasattr(block, "mlp") and hasattr(block.mlp, "num_funcs"):
                        k = block.mlp.num_funcs
                        block.mlp.force_func = random.randint(0, max(k - 1, 0))

        if step % 1000 == 0:
            print(
                f"[router-sched] step={step} tau={new_tau:.4f} aux_w={new_aux:.5f} "
                f"freeze<={self.tau_freeze_steps} force_p={self.force_prob:.3f} warmup<={self.force_warmup_steps}"
            )


class VeronicaTrainer(Trainer):
    def __init__(self, *args, label_smoothing: float = 0.0, rep_alpha: float = 0.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.label_smoothing = float(label_smoothing)
        self.rep_alpha = float(rep_alpha)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        Total loss = label-smoothed cross-entropy
                     + rep_alpha * mean probability assigned to tokens that repeat their predecessor
                     - router_aux_weight * router auxiliary term (treated as a bonus).
        """
        labels = inputs.get("labels")
        if labels is None:
            raise ValueError("compute_loss called without labels")
        model_inputs = {k: v for k, v in inputs.items() if k != "labels"}

        outputs = model(**model_inputs)
        logits = outputs.logits

        ignore_index = -100
        # Standard causal shift: predict token t+1 from positions <= t.
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()

        valid_mask = (shift_labels != ignore_index)
        safe_labels = shift_labels.clone()
        safe_labels[~valid_mask] = 0

        log_probs = F.log_softmax(shift_logits, dim=-1)
        nll_full = -log_probs.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1)
        nll_loss = nll_full[valid_mask].mean()

        if self.label_smoothing > 0.0:
            smooth_full = -log_probs.mean(dim=-1)
            smooth_loss = smooth_full[valid_mask].mean()
            ce_loss = (1.0 - self.label_smoothing) * nll_loss + self.label_smoothing * smooth_loss
        else:
            ce_loss = nll_loss

        total_loss = ce_loss

        # Repetition penalty: discourage high probability on targets equal to the previous token.
        if self.rep_alpha > 0.0:
            labels_prev = labels[:, :-1]
            labels_next = shift_labels
            valid_prev = (labels_prev != ignore_index)
            same_mask = valid_prev & valid_mask & (labels_prev == labels_next)
            if same_mask.any():
                rep_logp = log_probs.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1)
                rep_p = rep_logp[same_mask].exp()
                total_loss = total_loss + self.rep_alpha * rep_p.mean()

        # Router auxiliary term: subtracted, i.e. treated as a bonus to maximize.
        aux_loss = getattr(model, "_last_router_aux", None)
        if aux_loss is not None and hasattr(model, "config"):
            aux_w = float(getattr(model.config, "router_aux_weight", 0.0))
            if aux_w > 0:
                if not torch.is_tensor(aux_loss):
                    aux_loss = torch.as_tensor(aux_loss, device=logits.device, dtype=logits.dtype)
                total_loss = total_loss - aux_w * aux_loss.clamp_min(0.0)

        return (total_loss, outputs) if return_outputs else total_loss


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--dataset_paths", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True)

    parser.add_argument(
        "--tokenizer_candidates",
        type=str,
        nargs="*",
        default=["veronica-polymorphic/tokenizer", "gpt2"],
    )
    parser.add_argument(
        "--tokenizer_out",
        type=str,
        default="veronica-polymorphic/tokenizer_vmix",
    )

    parser.add_argument("--per_device_train_batch_size", type=int, default=4)
    parser.add_argument("--per_device_eval_batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--max_steps", type=int, default=60000)
    parser.add_argument("--learning_rate", type=float, default=2e-4)
    parser.add_argument("--warmup_ratio", type=float, default=0.02)
    parser.add_argument("--weight_decay", type=float, default=0.1)
    parser.add_argument("--eval_steps", type=int, default=1000)
    parser.add_argument("--save_steps", type=int, default=1000)
    parser.add_argument("--logging_steps", type=int, default=100)
    parser.add_argument("--label_smoothing", type=float, default=0.01)
    parser.add_argument("--rep_alpha", type=float, default=0.0)
    parser.add_argument("--mask_degenerate_runs", action="store_true")
    parser.add_argument("--seed", type=int, default=42)

    parser.add_argument(
        "--resume_from",
        type=str,
        default=None,
        help="Checkpoint to resume from (e.g., .../checkpoint-22000)",
    )

    parser.add_argument(
        "--max_seq_len",
        type=int,
        default=None,
        help="Maximum window length (e.g., 512, 1024, 2048). If None, uses the full dataset sequence.",
    )

    parser.add_argument("--router_tau_start", type=float, default=1.6)
    parser.add_argument("--router_tau_end", type=float, default=1.1)
    parser.add_argument("--router_aux_start", type=float, default=0.005)
    parser.add_argument("--router_aux_end", type=float, default=0.012)
    parser.add_argument("--router_tau_freeze_steps", type=int, default=4000,
                        help="Keep tau constant for the first N steps to avoid early specialization.")
    parser.add_argument("--router_force_prob", type=float, default=0.05,
                        help="Per-step probability to force a single branch during warmup.")
    parser.add_argument("--router_force_warmup_steps", type=int, default=3000,
                        help="Apply random branch forcing only within these initial steps.")

    args = parser.parse_args()

    tok = build_tokenizer(args.tokenizer_candidates, args.tokenizer_out)

    # Config adapted to the tokenizer vocab, seeded with the router schedule's starting values.
    cfg = load_cfg_with_vocab(args.config, tok)
    cfg.router_tau = args.router_tau_start
    cfg.router_aux_weight = args.router_aux_start

    model = init_model_from_config(cfg, tok)

    # Sanity check: the model's internal loss should roughly match a manual shifted cross-entropy.
    model.eval()
    with torch.no_grad():
        dummy = torch.randint(0, model.config.vocab_size, (1, 32), device=model.device)
        out = model(input_ids=dummy, labels=dummy)
        loss_model = out.loss.item()

        logits = out.logits
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = dummy[:, 1:].contiguous()
        loss_manual = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        ).item()

    print(f"[diag] loss_model_forward={loss_model:.4f} loss_manual_shift={loss_manual:.4f}")
    model.train()

    train_ds, val_ds = load_mix_dataset(args.dataset_paths)
    collator = CausalCollator(
        tokenizer=tok,
        mask_runs=args.mask_degenerate_runs,
        max_seq_len=args.max_seq_len,
    )

    resume_ckpt = args.resume_from or find_latest_checkpoint(args.output_dir)
    if resume_ckpt:
        print(f"🟢 Resuming from: {resume_ckpt}")
    else:
        print("⚪ No checkpoint: training from scratch.")

    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

    train_args = TrainingArguments(
        output_dir=args.output_dir,
        run_name=os.path.basename(args.output_dir.rstrip("/")),
        num_train_epochs=1_000,
        max_steps=args.max_steps,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        lr_scheduler_type="cosine",
        logging_steps=args.logging_steps,
        eval_steps=args.eval_steps,
        save_steps=args.save_steps,
        eval_strategy="steps",
        save_total_limit=5,
        bf16=use_bf16,
        fp16=(torch.cuda.is_available() and not use_bf16),
        gradient_checkpointing=True,
        report_to=["tensorboard"],
        dataloader_num_workers=2,
        seed=args.seed,
        label_smoothing_factor=0.0,  # smoothing is applied in VeronicaTrainer.compute_loss instead
        max_grad_norm=1.0,
        save_safetensors=False,
    )
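
    # With the docstring defaults (per-device batch 4, gradient accumulation 4,
    # single GPU assumed), the effective batch size is 4 * 4 = 16 sequences per
    # optimizer step.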

    callbacks: List[TrainerCallback] = [
        RouterAndSmokeCallback(tok),
        RouterScheduleCallback(
            tau_start=args.router_tau_start,
            tau_end=args.router_tau_end,
            aux_start=args.router_aux_start,
            aux_end=args.router_aux_end,
            total_steps=args.max_steps,
            tau_freeze_steps=args.router_tau_freeze_steps,
            force_prob=args.router_force_prob,
            force_warmup_steps=args.router_force_warmup_steps,
        ),
    ]

    trainer = VeronicaTrainer(
        model=model,
        args=train_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        tokenizer=tok,
        data_collator=collator,
        callbacks=callbacks,
        label_smoothing=args.label_smoothing,
        rep_alpha=args.rep_alpha,
    )

    # Final consistency check between tokenizer, embeddings, and lm_head.
    effective_vocab = len(tok)
    emb = model.get_input_embeddings().weight
    head = model.lm_head.weight
    assert emb.shape[0] == effective_vocab == head.shape[0], "Mismatch vocab/emb/lm_head"

    trainer.train(resume_from_checkpoint=resume_ckpt)
    trainer.save_state()
    trainer.save_model(args.output_dir)
    tok.save_pretrained(args.output_dir)
    with open(os.path.join(args.output_dir, "config.final.json"), "w", encoding="utf-8") as f:
        json.dump(model.config.to_dict(), f, indent=2)
    print("✅ Pretraining completed/saved.")


if __name__ == "__main__":
    main()