license: apache-2.0
datasets:
  - agentica-org/DeepScaleR-Preview-Dataset
base_model:
  - Vinnnf/Thinkless-1.5B-Warmup
pipeline_tag: text-generation
library_name: transformers

Thinkless: LLM Learns When to Think


Introduction

Reasoning language models, which are capable of extended chain-of-thought reasoning, have demonstrated remarkable performance on tasks that require complex logical inference. However, applying elaborate reasoning to every query is often computationally wasteful, particularly when many problems admit straightforward solutions. This motivates an open question: can LLMs learn when to think?

To answer this, we propose Thinkless, a learnable framework that enables an LLM to adaptively select between short-form and long-form reasoning based on both task complexity and the model's own ability. Thinkless is trained under a reinforcement learning paradigm and employs two control tokens: <short> for concise responses and <think> for detailed reasoning. At the core of our method is the Decoupled Group Relative Policy Optimization (DeGRPO) algorithm, which decomposes the learning objective of hybrid reasoning into two components: (1) a control token loss that governs the selection of the reasoning mode, and (2) a response loss that improves the accuracy of the generated answers. This decoupled formulation allows fine-grained control over the contribution of each objective, which stabilizes training and prevents the collapse observed in vanilla GRPO. Empirically, on benchmarks such as Minerva Algebra, MATH-500, and GSM8K, Thinkless reduces the usage of long-chain thinking by 50% to 90%, substantially lowering the computational cost of reasoning language models.
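To make the idea of decoupling more concrete, here is a minimal sketch in plain Python. It is not the DeGRPO implementation: this card does not give the exact objective, so the function names, the weight `alpha`, and the choice to average the response terms are all assumptions made for illustration. It only contrasts a single length-normalized average, where the control token shares one normalizer with every response token, against a decoupled form where the control-token term carries its own weight.

```python
from typing import Sequence


def length_normalized_objective(token_terms: Sequence[float]) -> float:
    """Average all per-token terms together, control token included.

    token_terms[0] is the (hypothetical) term for the control token;
    the remaining entries belong to the response tokens.
    """
    if not token_terms:
        raise ValueError("token_terms must contain at least the control token")
    return sum(token_terms) / len(token_terms)


def decoupled_objective(token_terms: Sequence[float], alpha: float = 1.0) -> float:
    """Weight the control-token term separately from the response terms.

    alpha is an assumed weighting coefficient for the control-token term;
    the response terms are averaged over the response length only.
    """
    if not token_terms:
        raise ValueError("token_terms must contain at least the control token")
    control_term = token_terms[0]
    response_terms = token_terms[1:]
    response_mean = (
        sum(response_terms) / len(response_terms) if response_terms else 0.0
    )
    return alpha * control_term + response_mean


# Toy sequences: only the control token carries a nonzero term.
short_seq = [1.0] + [0.0] * 9    # 1 control token + 9 response tokens
long_seq = [1.0] + [0.0] * 99    # 1 control token + 99 response tokens

for name, seq in (("short", short_seq), ("long", long_seq)):
    print(
        name,
        length_normalized_objective(seq),
        decoupled_objective(seq, alpha=0.5),
    )
```

In this toy setup the control token's share of the length-normalized average falls from 0.1 to 0.01 as the response grows from 9 to 99 tokens, while the decoupled form keeps it at `alpha` regardless of length. That length dependence is one plausible reading of why a separate control-token weight gives finer control over mode selection.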

QuickStart

from transformers import AutoModelForCausalLM, AutoTokenizer

# Path to a local checkpoint; point this at wherever the Thinkless weights are stored.
model_name = "checkpoints/trained/paper_final_450"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

instruction = "Please reason step by step, and put your final answer within \\boxed{}."
# prompt = "How many r's are in the word \"strawberry\""
# prompt = "The arithmetic mean of 7, 2, $x$ and 10 is 9. What is the value of $x$?"
# Raw string so that LaTeX backslashes (\le, \[, \]) are not treated as escape sequences.
prompt = r"Let $S$ be the set of points $(a,b)$ with $0 \le a,$ $b \le 1$ such that the equation \[x^4 + ax^3 - bx^2 + ax + 1 = 0\] has at least one real root.  Determine the area of the graph of $S.$"


messages = [
    {"role": "user", "content": f"{instruction}\n{prompt}"},
]

text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)

model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=16384
)
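# Strip the prompt tokens so that only the newly generated tokens remain.
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
# Length of the response in tokens (prompt excluded).
num_tokens = len(generated_ids[0])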

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

# The model signals long-form reasoning with the <think> control token.
think_mode = "<think>" in response

print(text+response)
print(f"\nThink Mode: {think_mode}")
print(f"Number of tokens: {num_tokens}")
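When running the snippet above over many prompts, it can be handy to record which mode the model picked for each one. The helpers below are a hypothetical sketch, not part of the released code: `detect_mode` and `think_rate` are illustrative names, and they assume the control tokens `<short>` and `<think>` appear as literal text at the start of the decoded response, which depends on how the tokenizer treats them when decoding.

```python
from typing import Iterable


def detect_mode(response: str) -> str:
    """Classify a decoded response by its leading control token.

    Returns "think", "short", or "unknown" when neither token leads the text.
    """
    stripped = response.lstrip()
    if stripped.startswith("<think>"):
        return "think"
    if stripped.startswith("<short>"):
        return "short"
    return "unknown"


def think_rate(responses: Iterable[str]) -> float:
    """Fraction of responses that used the long-form <think> mode."""
    modes = [detect_mode(r) for r in responses]
    if not modes:
        return 0.0
    return modes.count("think") / len(modes)


# Made-up responses, for illustration only.
examples = [
    "<think>\nLet me work through this step by step...",
    "<short>The answer is \\boxed{17}.",
]
print([detect_mode(r) for r in examples])
print(think_rate(examples))
```

A lower `think_rate` on a benchmark means the model answered more queries in the concise mode.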

Citation

If you find this work helpful, please cite: