---
library_name: transformers
license: apache-2.0
datasets:
  - AI-MO/NuminaMath-CoT
base_model:
  - Qwen/Qwen2.5-Math-1.5B
pipeline_tag: text-generation
tags:
  - reinforcement-learning
  - grpo
  - hierarchical
  - reasoning
  - math
  - tree-based
model_name: TreeRPO-Qwen2.5-Math-1.5B
---

TreeRPO-Qwen2.5-Math-1.5B

Summary:
A 1.5B-parameter math reasoning model fine-tuned with TreeRPO, a hierarchical extension of GRPO that assigns rewards to intermediate “thought” nodes, not just to full completions. Using just ~10K supervised + RL examples and no learned reward model, it reaches 86.4% greedy accuracy on GSM8K, versus 84.8% for Qwen2.5-Math-1.5B-Instruct.
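TreeRPO builds on GRPO, whose core step scores each sampled completion relative to the other completions drawn for the same prompt. The sketch below shows only that standard GRPO normalization (reward minus group mean, divided by group standard deviation); `group_relative_advantages` is a hypothetical helper name, it uses the population standard deviation (implementations differ), and how TreeRPO forms its comparison groups among thought nodes is described in the full write-up and is not assumed here.

```python
import statistics
from typing import List


def group_relative_advantages(rewards: List[float], eps: float = 1e-6) -> List[float]:
    """Standard GRPO-style advantage: (r - group mean) / (group std + eps).

    eps keeps the division finite when every reward in the group is equal.
    Uses the population standard deviation; this choice is an assumption.
    """
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]


# Binary rewards for four completions of one prompt: two correct, two wrong.
print([round(a, 3) for a in group_relative_advantages([1.0, 0.0, 0.0, 1.0])])
# [1.0, -1.0, -1.0, 1.0]

# If every completion gets the same reward, there is no learning signal.
print(group_relative_advantages([1.0, 1.0, 1.0]))
# [0.0, 0.0, 0.0]
```

With binary rewards, groups where all samples succeed or all fail contribute zero advantage, which is one motivation for assigning credit at finer granularity than the full completion.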

🔎 Full write-up (method, math, analysis):
TreeRPO: Hierarchical Credit Assignment for Reasoning in Language Models


Model Details

  • Base model: Qwen/Qwen2.5-Math-1.5B
  • Method: TreeRPO (tree-structured GRPO)
  • Reward signal: Deterministic, binary exact-match checker; each interior node's reward is the mean of its descendant leaf rewards.
  • Domain: Grade-school and intermediate math word problems (GSM8K style)
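The reward rule listed above can be illustrated with a small tree. This is a minimal sketch, not the TreeRPO training code: `ThoughtNode`, `exact_match`, and `assign_rewards` are hypothetical names, and the checker here compares answers after whitespace stripping only (real answer normalization is an assumption left out). Leaves are complete solutions scored 1.0 or 0.0; every interior node receives the mean over all leaves beneath it.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ThoughtNode:
    """Hypothetical node: one reasoning step ("thought") in a solution tree."""
    text: str
    children: List["ThoughtNode"] = field(default_factory=list)
    answer: Optional[str] = None  # final answer; set only on leaves
    reward: float = 0.0


def exact_match(pred: Optional[str], gold: str) -> float:
    """Binary checker: 1.0 if the stripped strings match, else 0.0 (normalization assumed)."""
    if pred is None:
        return 0.0
    return 1.0 if pred.strip() == gold.strip() else 0.0


def assign_rewards(node: ThoughtNode, gold: str) -> List[float]:
    """Set .reward on every node in the subtree and return the subtree's leaf rewards."""
    if not node.children:
        node.reward = exact_match(node.answer, gold)
        return [node.reward]
    leaf_rewards: List[float] = []
    for child in node.children:
        leaf_rewards.extend(assign_rewards(child, gold))
    # Interior node: mean over ALL descendant leaves, not the mean of child means.
    node.reward = sum(leaf_rewards) / len(leaf_rewards)
    return leaf_rewards


root = ThoughtNode("Subtract 5 from both sides", children=[
    ThoughtNode("3x = 12", children=[
        ThoughtNode("x = 4", answer="4"),
        ThoughtNode("x = 3", answer="3"),
    ]),
    ThoughtNode("3x = 22", children=[ThoughtNode("x = 22/3", answer="22/3")]),
])
assign_rewards(root, gold="4")
print(round(root.reward, 3))     # 0.333
print(root.children[0].reward)   # 0.5
```

Averaging over descendant leaves (1/3 at the root here) differs from averaging the children's averages (0.25), so subtrees that contain more completions carry proportionally more weight.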

Intended Use

Research on hierarchical RL for reasoning; math tutoring with human oversight; and use as a research baseline for domains with deterministic pass/fail checks (potentially extensible to code with unit tests).

Not intended for:
Open-ended or unsafe dialog, general factual QA, or high-stakes applications.


Evaluation (GSM8K Test Set, 1,319 problems)

| Model | Greedy (%) | Maj@8 (%) | Notes |
|---|---|---|---|
| Qwen2.5-Math-1.5B-Instruct | 84.8 | 89.5 | Reported settings |
| Qwen2.5-Math-1.5B-TreeRPO | 86.4 | 89.6 | Same decoding (greedy: temp 0; Maj@8: temp 0.7, top-p 0.8) |

  • Greedy: temperature = 0 (deterministic)
  • Maj@8: 8 completions (temperature 0.7, top-p 0.8); majority vote on final boxed answer
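A minimal sketch of the Maj@8 scoring step described above, assuming the vote is an exact string match on the contents of the last `\boxed{...}` in each completion. The helpers `last_boxed` and `majority_vote` are hypothetical; completions without a boxed answer are skipped, answer normalization (e.g. `4` vs `4.0`) is not shown, and ties go to the answer seen first. These are assumptions, not the documented evaluation script.

```python
from collections import Counter
from typing import List, Optional


def last_boxed(text: str) -> Optional[str]:
    """Return the contents of the last \\boxed{...} in text, handling nested braces."""
    start = text.rfind("\\boxed{")
    if start == -1:
        return None
    i = content_start = start + len("\\boxed{")
    depth = 1
    while i < len(text):
        if text[i] == "{":
            depth += 1
        elif text[i] == "}":
            depth -= 1
            if depth == 0:
                return text[content_start:i].strip()
        i += 1
    return None  # unbalanced braces: treat as no answer


def majority_vote(completions: List[str]) -> Optional[str]:
    """Most common boxed answer; Counter breaks ties by first occurrence."""
    answers = [a for a in map(last_boxed, completions) if a is not None]
    if not answers:
        return None
    return Counter(answers).most_common(1)[0][0]


samples = [
    "3x = 12, so x = \\boxed{4}",
    "x = \\boxed{4}",
    "x = \\boxed{3}",
    "I am not sure.",
]
print(majority_vote(samples))                # 4
print(last_boxed("\\boxed{\\frac{1}{2}}"))   # \frac{1}{2}
```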

How to Use

If your Transformers version supports chat templates (≥4.38), use:

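```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "omrisap/TreeRPO-Qwen2.5-Math-1.5B"
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")

messages = [
    {"role": "system", "content": "You are a helpful math reasoning assistant. Provide step-by-step reasoning and put the final answer in \\boxed{}."},
    {"role": "user",   "content": "If 3x + 5 = 17, what is x?"}
]

# Render the conversation with the model's chat template, ending with the assistant turn header
prompt_text = tok.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)

inputs = tok(prompt_text, return_tensors="pt").to(model.device)
# Greedy decoding (the "Greedy" evaluation setting). Use do_sample=False rather than
# temperature=0.0, which is rejected if sampling is enabled in the generation config.
outputs = model.generate(**inputs, max_new_tokens=256, do_sample=False)
# Decode only the newly generated tokens, not the echoed prompt
new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
print(tok.decode(new_tokens, skip_special_tokens=True))
```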