Update README.md
README.md
CHANGED
@@ -40,7 +40,7 @@ Research on hierarchical RL for reasoning; math tutoring prototypes with human o
 | Model | Greedy (%) | Maj@8 (%) | Notes |
 |-------|------------|-----------|-------|
 | Qwen2.5-Math-1.5B-Instruct | 84.8 | 89.5 | Reported settings |
-| **
+| **Qwen2.5-Math-1.5B-TreeRPO** | **86.4** | **89.6** | Same decoding (temp 0 / temp 0.7, top-p 0.8) |
 
 - **Greedy** = temperature 0, deterministic decoding.
 - **Maj@8** = majority vote over 8 sampled completions (temp 0.7, top-p 0.8) on the final boxed answer; ties or a missing boxed answer count as incorrect (see the sketch after the diff). Single-run numbers, no multi-seed variance.
@@ -55,7 +55,18 @@ model_name = "your-namespace/TreeRPO-Qwen2.5-Math-1.5B"
 tok = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
 
-
-
-
-
+messages = [
+    {"role": "system", "content": "You are a helpful math reasoning assistant. Provide step-by-step reasoning and put the final answer in \\boxed{}."},
+    {"role": "user", "content": "If 3x + 5 = 17, what is x?"}
+]
+
+prompt_text = tok.apply_chat_template(
+    messages,
+    tokenize=False,
+    add_generation_prompt=True
+)
+
+inputs = tok(prompt_text, return_tensors="pt").to(model.device)
+outputs = model.generate(**inputs, max_new_tokens=256, do_sample=False)  # greedy (temperature 0) decoding
+print(tok.decode(outputs[0], skip_special_tokens=True))
+
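The README does not ship an evaluation script, so here is a minimal sketch of the **Maj@8** rule described above, continuing from the `tok`, `model`, and `inputs` objects defined in the usage snippet: sample 8 completions at temp 0.7 / top-p 0.8, extract the last `\boxed{...}` answer from each, and take a majority vote, counting ties and missing boxed answers as incorrect. The helpers `extract_boxed` and `maj_at_8_correct` and the exact-string comparison are illustrative assumptions, not the authors' evaluation code.

```python
import re
from collections import Counter

def extract_boxed(text):
    """Return the contents of the last \\boxed{...} in `text`, or None.
    Handles one level of nested braces, which covers most final answers."""
    matches = re.findall(r"\\boxed\{((?:[^{}]|\{[^{}]*\})*)\}", text)
    return matches[-1].strip() if matches else None

def maj_at_8_correct(completions, reference):
    """Majority vote over boxed answers; ties and missing answers count as incorrect."""
    answers = [a for a in (extract_boxed(c) for c in completions) if a is not None]
    if not answers:
        return False
    counts = Counter(answers)
    top, top_n = counts.most_common(1)[0]
    if sum(1 for n in counts.values() if n == top_n) > 1:  # tie for first place
        return False
    return top == reference  # naive exact-string match (an assumption)

# Sample 8 completions with the reported Maj@8 settings.
outputs = model.generate(
    **inputs,
    max_new_tokens=512,
    do_sample=True,
    temperature=0.7,
    top_p=0.8,
    num_return_sequences=8,
)
prompt_len = inputs["input_ids"].shape[1]
completions = [tok.decode(o[prompt_len:], skip_special_tokens=True) for o in outputs]
print(maj_at_8_correct(completions, reference="4"))  # 3x + 5 = 17 gives x = 4
```

Note that exact string matching is the simplest possible grader; equivalent answers written differently (e.g. `1/2` vs `0.5`) would need a normalization step that the README does not specify.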