omrisap committed · verified
Commit 76767fe · Parent: 4901dec

Update README.md

Files changed (1): README.md (+24 −16)
README.md CHANGED
@@ -18,35 +18,44 @@ model_name: TreeRPO-Qwen2.5-Math-1.5B
 
 # TreeRPO-Qwen2.5-Math-1.5B
 
-**Short summary:** A 1.5B parameter math reasoning model fine-tuned with *TreeRPO*, a hierarchical extension of GRPO that assigns rewards to “thought” nodes instead of whole sequences—achieving higher GSM8K accuracy with **~10K total** supervised + RL examples and **no reward model**.
+**Summary:**
+A 1.5B-parameter math reasoning model fine-tuned with **TreeRPO**, a hierarchical extension of GRPO that assigns rewards to “thought” nodes (not just full completions). It achieves higher GSM8K accuracy with only ~10K supervised + RL examples and **no reward model**.
 
 🔎 **Full write-up (method, math, analysis):**
-https://omrisapir.substack.com/publish/post/167273414
+[TreeRPO: Hierarchical Credit Assignment for Data-Efficient Math Reasoning](https://omrisapir.substack.com/publish/post/167273414)
+
+---
 
 ## Model Details
-- **Base model:** `Qwen/Qwen2.5-Math-1.5B`
-- **Method:** TreeRPO
-- **Reward signal:** Deterministic exact-match checker (binary). Interior node reward = average of descendant leaf rewards.
-- **Intended domain:** Grade-school & intermediate math word problems (GSM8K style)
+- **Base model:** [`Qwen/Qwen2.5-Math-1.5B`](https://huggingface.co/Qwen/Qwen2.5-Math-1.5B)
+- **Method:** TreeRPO (tree-structured GRPO; trees up to depth 7; branching driven by entropy & length)
+- **Reward signal:** Deterministic exact-match checker (binary). Interior node reward = mean of its descendant leaf rewards.
+- **Domain:** Grade-school and intermediate math word problems (GSM8K style)
 
 ## Intended Use
-Research on hierarchical RL for reasoning; math tutoring prototypes with human oversight; experimentation in deterministic pass/fail domains (e.g., potential extension to code with unit tests).
+Research on hierarchical RL for reasoning; math tutoring prototypes (with human oversight); or as a research baseline for deterministic pass/fail domains (with potential extension to code scored by unit tests).
 
-**Not intended for:** Open-ended unsafe dialogue, factual QA outside math, high-stakes decision making.
-single 48GB GPU (~18h RL phase).
+**Not intended for:**
+Open-ended or unsafe dialogue, general factual QA, or high-stakes applications.
+
+---
 
 ## Evaluation (GSM8K Test Set, 1,319 problems)
 
-| Model | Greedy (%) | Maj@8 (%) | Notes |
-|-------|------------|-----------|-------|
-| Qwen2.5-Math-1.5B-Instruct | 84.8 | 89.5 | Reported settings |
-| **Qwen2.5-Math-1.5B-TreeRPO** | **86.4** | **89.6** | Same decoding (temp 0 / (0.7, top-p 0.8)) |
+| Model                         | Greedy (%) | Maj@8 (%) | Notes                                        |
+|-------------------------------|------------|-----------|----------------------------------------------|
+| Qwen2.5-Math-1.5B-Instruct    | 84.8       | 89.5      | Reported settings                            |
+| **TreeRPO-Qwen2.5-Math-1.5B** | **86.4**   | **89.6**  | Same decoding (temp 0 / temp 0.7, top-p 0.8) |
 
-- **Greedy** = temperature 0 deterministic decoding.
-- **Maj@8** = 8 sampled completions (temp 0.7, top-p 0.8) majority vote on final boxed answer. Ties / missing boxed answer → incorrect. Single-run numbers (no multi-seed variance).
+- **Greedy:** temperature = 0 (deterministic decoding)
+- **Maj@8:** 8 sampled completions (temperature 0.7, top-p 0.8); majority vote on the final boxed answer; ties or a missing boxed answer count as incorrect; single-run numbers (no multi-seed variance)
+
+---
 
 ## How to Use
 
+If your Transformers version supports chat templates (≥4.38), use:
+
 ```python
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
@@ -69,4 +78,3 @@ prompt_text = tok.apply_chat_template(
 inputs = tok(prompt_text, return_tensors="pt").to(model.device)
 outputs = model.generate(**inputs, max_new_tokens=256, temperature=0.0)
 print(tok.decode(outputs[0], skip_special_tokens=True))
-
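
**Note on the reward signal.** The *Model Details* bullet defines an interior node's reward as the mean of its descendant leaf rewards, with leaves scored 0/1 by the exact-match checker. Below is a minimal sketch of that propagation on a toy tree; the `Node` class and function are illustrative, not the repo's training code:

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    children: list["Node"] = field(default_factory=list)
    reward: float = 0.0  # leaves: set by the exact-match checker (0.0 or 1.0)

def propagate_rewards(node: Node) -> list[float]:
    """Return all leaf rewards under `node`, filling interior rewards in place."""
    if not node.children:  # leaf: reward already assigned by the checker
        return [node.reward]
    leaf_rewards = [r for child in node.children for r in propagate_rewards(child)]
    node.reward = sum(leaf_rewards) / len(leaf_rewards)  # mean over descendant leaves
    return leaf_rewards

# Toy tree: three leaves scored 1, 0, 1 by the checker.
root = Node(children=[
    Node(children=[Node(reward=1.0), Node(reward=0.0)]),
    Node(children=[Node(reward=1.0)]),
])
propagate_rewards(root)
print(root.reward)  # 0.666... — the mean of the descendant leaf rewards
```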
 
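**Note on Maj@8.** The evaluation takes a majority vote over the final boxed answers of 8 sampled completions, counting ties or a missing boxed answer as incorrect. Here is a minimal sketch of that scoring rule; the regex and function names are assumptions, not the authors' evaluation harness:

```python
import re
from collections import Counter
from typing import Optional

def extract_boxed(text: str) -> Optional[str]:
    """Return the contents of the last \\boxed{...} in a completion (non-nested)."""
    matches = re.findall(r"\\boxed\{([^{}]*)\}", text)
    return matches[-1].strip() if matches else None

def maj_at_8_correct(completions: list[str], gold: str) -> bool:
    """Majority vote over boxed answers; ties or no boxed answer score as incorrect."""
    answers = [a for a in (extract_boxed(c) for c in completions) if a is not None]
    if not answers:
        return False
    ranked = Counter(answers).most_common()
    top_answer, top_count = ranked[0]
    if len(ranked) > 1 and ranked[1][1] == top_count:  # tie for first place
        return False
    return top_answer == gold

# Example: 5 of 8 samples agree on the gold answer "42".
samples = ["... \\boxed{42}"] * 5 + ["... \\boxed{41}"] * 2 + ["no final answer"]
print(maj_at_8_correct(samples, "42"))  # True
```

Greedy accuracy is the same exact-match check applied to a single temperature-0 completion.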
 
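**Note on the usage snippet.** The diff shows only the head and tail of the *How to Use* example; the middle is elided. The sketch below fills the gap with assumed Qwen-style chat messages and an assumed repo id (`omrisap/TreeRPO-Qwen2.5-Math-1.5B`); treat both as placeholders, not the card's exact code:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed repo id; substitute the actual Hub path if it differs.
model_id = "omrisap/TreeRPO-Qwen2.5-Math-1.5B"
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

# Assumed Qwen-Math-style prompt asking for a boxed final answer.
messages = [
    {"role": "system",
     "content": "Please reason step by step, and put your final answer within \\boxed{}."},
    {"role": "user",
     "content": "Natalia sold clips to 48 friends in April, and half as many in May. "
                "How many clips did she sell altogether?"},
]
prompt_text = tok.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
inputs = tok(prompt_text, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=256, do_sample=False)  # greedy
print(tok.decode(outputs[0], skip_special_tokens=True))
```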