---
library_name: transformers
license: apache-2.0
datasets:
- AI-MO/NuminaMath-CoT
base_model:
- Qwen/Qwen2.5-Math-1.5B
pipeline_tag: text-generation
tags:
- reinforcement-learning
- grpo
- hierarchical
- reasoning
- math
- tree-based
model_name: TreeRPO-Qwen2.5-Math-1.5B
---
# TreeRPO-Qwen2.5-Math-1.5B
**Summary:**
A 1.5B-parameter math reasoning model fine-tuned with **TreeRPO**, a hierarchical extension of GRPO that assigns rewards to intermediate “thought” nodes rather than only full completions. It achieves higher GSM8K accuracy than the instruct baseline with just ~10K supervised + RL examples and **no reward model**.
🔎 **Full write-up (method, math, analysis):**
[TreeRPO: Hierarchical Credit Assignment for Reasoning in Language Models](https://omrisapir.substack.com/publish/post/167273414)
---
## Model Details
- **Base model:** [`Qwen/Qwen2.5-Math-1.5B`](https://huggingface.co/Qwen/Qwen2.5-Math-1.5B)
- **Method:** TreeRPO (tree-structured GRPO)
- **Reward signal:** Deterministic exact-match checker (binary). Interior-node rewards are the mean of their descendant leaves' rewards; see the sketch after this list.
- **Domain:** Grade-school and intermediate math word problems (GSM8K style)
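A minimal sketch of this reward propagation, assuming a simple tree representation (the `Node` class and function names here are illustrative, not taken from the released training code):

```python
# Illustrative sketch of TreeRPO's hierarchical reward propagation:
# leaves hold binary exact-match rewards; each interior node's reward
# is the mean of its descendant leaves' rewards.
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class Node:
    children: list["Node"] = field(default_factory=list)
    reward: Optional[float] = None  # set directly on leaves

def exact_match(pred: str, gold: str) -> float:
    """Binary leaf reward from the deterministic checker."""
    return 1.0 if pred.strip() == gold.strip() else 0.0

def propagate(node: Node) -> list[float]:
    """Set each interior node's reward to the mean of its descendant
    leaves' rewards; return the list of those leaf rewards."""
    if not node.children:
        return [node.reward]
    leaves = [r for child in node.children for r in propagate(child)]
    node.reward = sum(leaves) / len(leaves)
    return leaves

# Two of three leaf completions are correct -> root reward = 2/3.
root = Node(children=[
    Node(reward=exact_match("4", "4")),
    Node(children=[Node(reward=exact_match("4", "4")),
                   Node(reward=exact_match("5", "4"))]),
])
propagate(root)
print(root.reward)  # 0.666...
```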
## Intended Use
Research on hierarchical RL for reasoning; math tutoring with human oversight; and research baselines for deterministic pass/fail domains (potentially extensible to code with unit tests).
**Not intended for:**
Open-ended or unsafe dialog, general factual QA, or high-stakes applications.
---
## Evaluation (GSM8K Test Set, 1,319 problems)
| Model | Greedy (%) | Maj@8 (%) | Notes |
|---------------------------------|------------|-----------|--------------------------------------|
| Qwen2.5-Math-1.5B-Instruct | 84.8 | 89.5 | Reported settings |
| **Qwen2.5-Math-1.5B-TreeRPO**   | **86.4**   | **89.6**  | Same decoding settings (see below)   |
- **Greedy:** temperature = 0 (deterministic)
- **Maj@8:** 8 sampled completions (temperature 0.7, top-p 0.8); majority vote on the final boxed answer, as sketched below
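A minimal sketch of the Maj@8 scoring, assuming answers appear as `\boxed{...}` (this simple regex does not handle nested braces; the exact extraction used in the evaluation may differ):

```python
# Sketch of Maj@8: extract the last boxed answer from each completion,
# then take the most common answer across completions.
import re
from collections import Counter

def extract_boxed(text: str):
    """Return the contents of the last \\boxed{...} span, if any."""
    matches = re.findall(r"\\boxed\{([^{}]*)\}", text)
    return matches[-1] if matches else None

def majority_vote(completions):
    answers = [a for a in map(extract_boxed, completions) if a is not None]
    return Counter(answers).most_common(1)[0][0] if answers else None

votes = [r"... \boxed{4}", r"... \boxed{4}", r"... \boxed{5}"]
print(majority_vote(votes))  # "4"
```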
---
## How to Use
If your Transformers version supports chat templates (≥4.38), use:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model_name = "omrisap/TreeRPO-Qwen2.5-Math-1.5B"
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
messages = [
{"role": "system", "content": "You are a helpful math reasoning assistant. Provide step-by-step reasoning and put the final answer in \\boxed{}."},
{"role": "user", "content": "If 3x + 5 = 17, what is x?"}
]
prompt_text = tok.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
inputs = tok(prompt_text, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=256, do_sample=False)  # greedy decoding
print(tok.decode(outputs[0], skip_special_tokens=True))
```
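To reproduce the Maj@8 setting from the evaluation table, sample several completions instead of decoding greedily and majority-vote on the boxed answers (see the sketch in the Evaluation section). A minimal variant of the call above:

```python
# Sample 8 completions with the Maj@8 decoding settings.
outputs = model.generate(
    **inputs,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
    top_p=0.8,
    num_return_sequences=8,
)
completions = tok.batch_decode(outputs, skip_special_tokens=True)
```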