---
library_name: transformers
license: apache-2.0
datasets:
- AI-MO/NuminaMath-CoT
base_model:
- Qwen/Qwen2.5-Math-1.5B
pipeline_tag: text-generation
tags:
- reinforcement-learning
- grpo
- hierarchical
- reasoning
- math
- tree-based
model_name: TreeRPO-Qwen2.5-Math-1.5B
---

# TreeRPO-Qwen2.5-Math-1.5B

**Summary:**  
A 1.5B-parameter math reasoning model fine-tuned with **TreeRPO**, a hierarchical extension of GRPO that assigns rewards to “thought” nodes (not just full completions). Achieves higher GSM8K accuracy than Qwen2.5-Math-1.5B-Instruct with just ~10K supervised + RL examples and **no reward model**.

🔎 **Full write-up (method, math, analysis):**  
[TreeRPO: Hierarchical Credit Assignment for Reasoning in Language Models](https://omrisapir.substack.com/publish/post/167273414)

---

## Model Details
- **Base model:** [`Qwen/Qwen2.5-Math-1.5B`](https://huggingface.co/Qwen/Qwen2.5-Math-1.5B)
- **Method:** TreeRPO (tree-structured GRPO)
- **Reward signal:** Deterministic exact-match checker (binary) on leaf completions; an interior node's reward is the mean of its descendant leaf rewards (sketched after this list)
- **Domain:** Grade-school and intermediate math word problems (GSM8K style)
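
The hierarchical credit assignment fits in a few lines. Below is a minimal sketch: leaves carry the binary exact-match reward, and every interior “thought” node receives the mean of its descendant leaves. The `Node` class and helper names are illustrative assumptions, not the actual training code.

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    # Leaves get their reward directly from the exact-match checker (0.0 or 1.0);
    # interior rewards are filled in by assign_rewards below.
    reward: float | None = None
    children: list["Node"] = field(default_factory=list)

def leaf_rewards(node: Node) -> list[float]:
    """Collect the binary rewards of all descendant leaves."""
    if not node.children:
        return [node.reward]
    out: list[float] = []
    for child in node.children:
        out.extend(leaf_rewards(child))
    return out

def assign_rewards(node: Node) -> None:
    """Set each interior node's reward to the mean of its descendant leaf rewards."""
    if node.children:
        for child in node.children:
            assign_rewards(child)
        leaves = leaf_rewards(node)
        node.reward = sum(leaves) / len(leaves)

# A "thought" whose three sampled continuations score 1, 0, 1 gets reward 2/3.
thought = Node(children=[Node(reward=1.0), Node(reward=0.0), Node(reward=1.0)])
assign_rewards(thought)
print(thought.reward)  # 0.666...
```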

## Intended Use
Research on hierarchical RL for reasoning; math tutoring (with human oversight); or a research baseline for deterministic pass/fail domains (with potential extension to code checked by unit tests).

**Not intended for:**  
Open-ended or unsafe dialog, general factual QA, or high-stakes applications.

---

## Evaluation (GSM8K Test Set, 1,319 problems)

| Model                          | Greedy (%) | Maj@8 (%) | Notes                                |
|---------------------------------|------------|-----------|--------------------------------------|
| Qwen2.5-Math-1.5B-Instruct      | 84.8       | 89.5      | Reported settings                    |
| **TreeRPO-Qwen2.5-Math-1.5B**   | **86.4**   | **89.6**  | Same decoding settings (see below)   |

- **Greedy:** temperature = 0 (deterministic)
- **Maj@8:** 8 sampled completions (temperature 0.7, top-p 0.8); majority vote on the final boxed answer, as in the sketch below
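
Per problem, Maj@8 scoring works roughly like this sketch; the regex and helper names are assumptions about the `\boxed{}` convention, not the exact evaluation code.

```python
import re
from collections import Counter

def extract_boxed(text: str) -> str | None:
    """Return the contents of the last \\boxed{...} (simple, non-nested braces)."""
    matches = re.findall(r"\\boxed\{([^{}]*)\}", text)
    return matches[-1] if matches else None

def majority_answer(completions: list[str]) -> str | None:
    """Majority vote over the boxed answers of the sampled completions."""
    answers = [a for a in (extract_boxed(c) for c in completions) if a is not None]
    return Counter(answers).most_common(1)[0][0] if answers else None

# 5 of 8 samples agree on 4, so the voted answer is "4".
samples = ["... so x = \\boxed{4}."] * 5 + ["... \\boxed{5}"] * 3
print(majority_answer(samples))  # 4
```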

---

## How to Use

If your Transformers version supports chat templates (≥4.38), use:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "omrisap/TreeRPO-Qwen2.5-Math-1.5B"
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")

messages = [
    {"role": "system", "content": "You are a helpful math reasoning assistant. Provide step-by-step reasoning and put the final answer in \\boxed{}."},
    {"role": "user",   "content": "If 3x + 5 = 17, what is x?"}
]

prompt_text = tok.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)

inputs = tok(prompt_text, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=256, do_sample=False)  # greedy decoding
print(tok.decode(outputs[0], skip_special_tokens=True))
```
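
On older Transformers versions without `apply_chat_template`, you can build the prompt by hand. A sketch assuming the ChatML-style markup Qwen models use; verify the special tokens against this repo's `tokenizer_config.json`:

```python
# Assumption: Qwen-family ChatML markup (<|im_start|> / <|im_end|>).
system = ("You are a helpful math reasoning assistant. Provide step-by-step "
          "reasoning and put the final answer in \\boxed{}.")
question = "If 3x + 5 = 17, what is x?"
prompt_text = (
    f"<|im_start|>system\n{system}<|im_end|>\n"
    f"<|im_start|>user\n{question}<|im_end|>\n"
    "<|im_start|>assistant\n"
)
inputs = tok(prompt_text, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=256, do_sample=False)
print(tok.decode(outputs[0], skip_special_tokens=True))
```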