We train lightweight predictors to estimate how much reasoning each query needs before generation begins, then use a greedy algorithm to allocate a fixed token budget where it matters most. The result: up to 7.9 percentage points higher accuracy at identical compute cost, closing over half the gap to an oracle with perfect foresight.
Chain-of-thought (CoT) reasoning has become the dominant paradigm for getting LLMs to solve complex problems. Instead of jumping straight to an answer, the model "thinks out loud," generating intermediate reasoning steps that guide it toward the correct solution. This approach has unlocked remarkable capabilities in logic, math, coding, and multi-step problem solving.
However, in practice, we typically allocate a fixed token budget to every query—say, 256 tokens of reasoning per problem. This uniform approach creates two distinct failure modes:
Easy queries waste compute. Consider a simple arithmetic problem like "What is 15% of 80?" A competent model can solve this in 30-40 tokens. Giving it 256 tokens means we're burning through unnecessary computation, adding latency, and inflating API costs—all for no accuracy benefit.
Hard queries get starved. Meanwhile, a multi-step word problem involving rates, ratios, and unit conversions might genuinely need 200+ tokens to work through correctly. If we've set our budget at 128 tokens to save costs, we're forcing the model to truncate its reasoning mid-thought, almost guaranteeing a wrong answer.
The waste is staggering at scale. Production LLM services handle millions of queries daily. If even 30% of those queries could be solved with half the typical token budget, the savings in compute, latency, and cost would be enormous. Conversely, if we're systematically under-allocating to hard problems, we're leaving accuracy on the table.
This raises a natural question: what if we could predict, before generating a single token, how much reasoning each query actually needs?
The core insight behind predictive scheduling is that the information needed to estimate a query's difficulty is already present in the model—we just need to extract it. When an LLM processes a question, its hidden states encode rich information about the problem's structure and likely solution path. We can train lightweight predictors to read these signals and forecast how much reasoning will be required.
Our framework operates in three stages. First, we run a fast forward pass through the model to extract hidden states from the query. Second, a trained predictor maps these hidden states to an estimate of the query's "reasoning requirements." Third, a scheduling algorithm uses these estimates to distribute a fixed total token budget across all queries in a batch, giving more tokens to queries that need them and fewer to queries that don't.
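As a rough sketch, the batch-level flow looks something like this. The function names `extract_hidden_state` and `predictor` below are illustrative placeholders rather than a real API; the predictor and the `greedy_allocate` routine are described concretely in the rest of the post.

```python
import numpy as np

def schedule_batch(model, queries, budget_per_query):
    """Illustrative three-stage pipeline; names are placeholders, not a real API."""
    # Stage 1: one cheap forward pass per query to read hidden states at a
    # chosen intermediate layer (no reasoning tokens are generated here).
    hidden = np.stack([extract_hidden_state(model, q, layer=16) for q in queries])

    # Stage 2: a lightweight predictor maps hidden states to a per-query
    # early-stopping probability vector of shape (K,).
    prob_vectors = predictor(hidden)  # shape: (N, K)

    # Stage 3: distribute the fixed total budget across the batch.
    return greedy_allocate(prob_vectors, budget_per_query)
```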
To train our predictors, we first need ground truth labels that capture how a query's success probability changes with token budget. We construct these labels empirically using what we call "early-stopping probability vectors."
For each query in our training set, we generate 100 independent reasoning traces using temperature sampling. Within each trace, we insert a forced-answer probe at regular intervals (every 16 tokens, up to 256 tokens). The probe is a string that forces the model to commit to a final answer at that point:
"Oh, I suddenly got the answer to the whole problem, Final Answer\n\n\[ \boxed{"
By comparing these forced answers against the ground truth, we can compute the empirical probability of success at each checkpoint. The result is a 16-dimensional vector for each query, where entry k represents P(correct | k×16 tokens of reasoning).
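Once the forced answers have been graded against the ground truth, assembling the label vectors is just an average over traces. A minimal sketch, assuming correctness has already been recorded in a boolean array (generation and grading omitted):

```python
import numpy as np

def build_early_stopping_vectors(correct):
    """
    correct: boolean array of shape (n_queries, n_traces, n_checkpoints), where
             correct[i, t, k] is True if the forced answer for query i, trace t,
             at checkpoint k (i.e. after (k + 1) * 16 reasoning tokens) is right.
    Returns an (n_queries, n_checkpoints) array whose entry [i, k] is the
    empirical P(correct | (k + 1) * 16 tokens) for query i.
    """
    return correct.mean(axis=1)

# e.g. 100 traces per query, 16 checkpoints spaced every 16 tokens:
# prob_vectors = build_early_stopping_vectors(correct)  # correct: (N, 100, 16)
```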
These vectors reveal striking patterns. Easy problems show a rapid rise in success probability within the first few checkpoints, then plateau near 1.0. Hard problems remain near zero until much higher token budgets, and may never reach high accuracy even at maximum budget. Medium problems fall somewhere in between, with gradual improvement across the token range.
Our first prediction approach extracts hidden states from intermediate transformer layers and trains a simple MLP to predict the early-stopping probability vector. The architecture is deliberately lightweight: two fully-connected layers with 256 hidden units, ReLU activations, and a final sigmoid to constrain outputs to [0,1]. We train separate MLPs for each transformer layer to systematically evaluate which layers encode the most predictive signal.
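In PyTorch, a predictor of this shape takes only a few lines. The sketch below follows the description above (two fully-connected layers, 256 hidden units, ReLU, final sigmoid); the backbone's hidden size and the training loss are assumptions.

```python
import torch.nn as nn

class EarlyStoppingMLP(nn.Module):
    """Maps one layer's hidden state (d_model,) to a 16-dim probability vector."""
    def __init__(self, d_model=1536, hidden=256, n_checkpoints=16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_checkpoints),
            nn.Sigmoid(),  # constrain outputs to [0, 1]
        )

    def forward(self, h):
        return self.net(h)

# One MLP is trained per transformer layer, e.g.:
# predictor = EarlyStoppingMLP(d_model=1536)  # d_model assumed for a 1.5B model
# loss = nn.MSELoss()(predictor(hidden_states), target_vectors)  # assumed loss
```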
Why hidden states? When a transformer processes a query, each layer builds increasingly abstract representations of the input. Early layers capture surface-level features like syntax and token identity. Middle layers integrate these into semantic representations that capture meaning and relationships. Late layers specialize toward the specific output format. Our hypothesis was that middle layers—which balance syntactic and semantic information—would provide the strongest signal for reasoning difficulty prediction.
Our second approach takes a more integrated route: we fine-tune the base LLM itself (with LoRA) to predict early-stopping probabilities directly from the raw question text. This method uses rank-16 LoRA adapters on the query and value projections, followed by a regression head that maps the final hidden state to a 16-dimensional probability vector.
The motivation here is that predicting reasoning complexity is fundamentally a linguistic task. Certain phrases and structures—"prove that," "derive," "show that"—tend to signal harder problems, while "calculate," "find," or "what is" often indicate more straightforward computations. A fine-tuned language model should be able to learn these patterns directly from text, without requiring explicit hidden state extraction.
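A minimal sketch of this setup with Hugging Face transformers and peft might look like the following. The rank and target modules match the description above; the scaling factor, the pooling of the final hidden state, and the exact head design are assumptions.

```python
import torch
import torch.nn as nn
from transformers import AutoModel
from peft import LoraConfig, get_peft_model

base = AutoModel.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
lora_cfg = LoraConfig(
    r=16,                                  # rank-16 adapters
    lora_alpha=32,                         # assumed scaling factor
    target_modules=["q_proj", "v_proj"],   # query and value projections
)
backbone = get_peft_model(base, lora_cfg)

class DifficultyRegressor(nn.Module):
    """LoRA-adapted backbone plus a regression head over the final hidden state."""
    def __init__(self, backbone, d_model=1536, n_checkpoints=16):
        super().__init__()
        self.backbone = backbone
        self.head = nn.Sequential(nn.Linear(d_model, n_checkpoints), nn.Sigmoid())

    def forward(self, input_ids, attention_mask):
        out = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        # Summarize the question with the last non-padding token's hidden state
        # (assumes right padding).
        last_idx = attention_mask.sum(dim=1) - 1
        h = out.last_hidden_state[torch.arange(input_ids.size(0)), last_idx]
        return self.head(h)
```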
One of our most striking findings concerns where in the transformer the predictive signal lives. We trained identical MLP predictors on hidden states from each of the 28 layers of DeepSeek-R1-Distill-Qwen-1.5B and measured their correlation with ground truth early-stopping probabilities.
The results show a clear inverted U-shape across model depth. Early layers (1-6) achieve correlations below 0.6—they capture syntactic features but lack the semantic depth needed for difficulty assessment. Late layers (21-28) show similar limitations, having specialized too heavily toward output generation to retain general reasoning signals. But middle layers (12-17) significantly outperform both extremes, with layer 16 achieving the highest test correlation of 0.742.
We also analyzed the efficiency of each layer using a correlation-to-loss ratio metric, which measures predictive power per unit of error. This analysis further confirms the effectiveness of middle layers, which achieve 15-20% higher efficiency than early or late layers. The implication is clear: if you're building a reasoning-length predictor, target the middle of the network.
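For reference, the per-layer comparison boils down to a few lines of NumPy. The efficiency metric here is our reading of "predictive power per unit of error," i.e. correlation divided by test loss; the exact definition is an assumption.

```python
import numpy as np

def layer_metrics(preds, targets):
    """
    preds, targets: (n_queries, n_checkpoints) arrays for one layer's predictor
    on the test set. Returns (pearson_r, mse, efficiency = r / mse).
    """
    r = np.corrcoef(preds.ravel(), targets.ravel())[0, 1]
    mse = np.mean((preds - targets) ** 2)
    return r, mse, r / mse
```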
Once we have predicted early-stopping probabilities for each query in a batch, we need an algorithm to translate these predictions into token allocations. The goal is to maximize expected accuracy across the batch subject to a fixed total token budget.
We adopt a greedy approach that iteratively allocates tokens to the query with the highest marginal gain. The algorithm begins by assigning each query a minimum allocation (16 tokens). Then, at each step, it computes the expected accuracy improvement from giving one additional token window to each query and allocates to the query with the highest gain. This process continues until the budget is exhausted or no query would benefit from additional tokens.
The ComputeGains step returns the marginal expected-accuracy improvement from adding one more window of W tokens (16 in our setup) to each query: for query i currently at checkpoint k, the gain is simply P[i, k+1] − P[i, k], the predicted increase in success probability from extending the reasoning budget.
This greedy algorithm has an intuitive interpretation. At low total budgets, it prioritizes easy queries because they offer the highest marginal gains—a small token investment yields a large jump in success probability. As the budget increases and easy queries saturate, the algorithm shifts to medium queries, then finally to hard queries that require substantial investment before showing any improvement.
Predicting full 16-dimensional probability vectors may be overkill for practical allocation. A simpler approach is to classify each query into one of three difficulty categories—easy, medium, or hard—and allocate tokens based on category membership. This reduces the prediction problem from regression to classification, which is often more robust to noise.
We define difficulty categories using the 256-token success probability from our training data. Queries in the bottom 20th percentile (success probability ≤ 0.18) are labeled "hard"; those in the top 20th percentile (≥ 0.84) are labeled "easy"; the remainder are "medium." This stratification produces balanced classes and captures the qualitative distinction between problems that are trivially solvable, genuinely challenging, or somewhere in between.
| Split | Total | Easy | Medium | Hard |
|---|---|---|---|---|
| Train | 7,450 | 1,506 | 4,437 | 1,507 |
| Test | 1,294 | 271 | 760 | 263 |
Difficulty thresholds derived from the training set: p₂₀ = 0.18 and p₈₀ = 0.84, based on 256-token accuracy.
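The stratification itself is a couple of lines of NumPy over the 256-token success probabilities; the percentile choices match the description above, and the label encoding (0 = easy, 1 = medium, 2 = hard) matches the allocation code later in the post.

```python
import numpy as np

def label_difficulty(success_at_max, p_low=20, p_high=80):
    """
    success_at_max: (N,) empirical P(correct | 256 tokens) per training query.
    Returns integer labels (0 = easy, 1 = medium, 2 = hard) and the thresholds.
    """
    lo, hi = np.percentile(success_at_max, [p_low, p_high])  # ~0.18, ~0.84
    labels = np.full(len(success_at_max), 1)   # medium by default
    labels[success_at_max <= lo] = 2           # hard: bottom 20th percentile
    labels[success_at_max >= hi] = 0           # easy: top 20th percentile
    return labels, (lo, hi)
```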
To validate this categorization, we computed average early-stopping curves for each difficulty class. The results confirm that our categories capture fundamentally different reasoning patterns: easy problems show rapid improvement and early plateau; medium problems improve gradually across the token range; hard problems remain near zero until the highest budgets. Crucially, these patterns are consistent between train and test sets, indicating that the categorization generalizes.
We evaluated two approaches for difficulty classification. The first uses few-shot prompting with a commercial LLM (o4-mini), providing three labeled examples and asking the model to classify new queries. The second fine-tunes our base model with LoRA to predict difficulty directly from question text.
The results strongly favor fine-tuning. The LoRA classifier achieves 66.3% test accuracy, compared to just 41.6% for the few-shot approach—a 24.7 percentage point improvement. Both methods struggle most with the medium category, which makes sense: the boundary between "medium" and "easy" (or "hard") is inherently fuzzy, and many problems could reasonably be classified either way.
Given difficulty predictions, the allocation algorithm is straightforward: we precompute optimal per-category budgets using the training set curves, then assign each query the budget corresponding to its predicted category. This approach is simpler than full greedy allocation and, surprisingly, often performs better in practice.
We evaluated our predictive scheduling approaches on the GSM8K test set, comparing against a uniform-allocation baseline and an oracle that uses ground-truth early-stopping probabilities. The results demonstrate that adaptive allocation consistently outperforms uniform budgeting, with difficulty-based allocation proving most effective.
At an average budget of 96 tokens per query, difficulty-based allocation achieves up to 7.9 percentage points higher accuracy than uniform allocation—at identical total compute cost. This improvement closes over 50% of the gap to the oracle upper bound, demonstrating that our predictors capture meaningful signal about reasoning requirements.
Interestingly, the size-based approach (using MLP predictions of continuous probability vectors) shows mixed results. It outperforms uniform allocation at low budgets (16-96 tokens) where constraints are tight and adaptive allocation matters most. But at higher budgets, prediction errors become more consequential, and size-based allocation actually underperforms the baseline. Difficulty-based allocation, with its more robust categorical predictions, avoids this pitfall and maintains consistent gains across all budget levels.
The core greedy allocation algorithm is simple to implement:
```python
import numpy as np


def greedy_allocate(prob_vectors, budget_per_query, window=16):
    """
    Allocate tokens greedily based on predicted early-stopping probabilities.

    Args:
        prob_vectors: (N, K) array of P(correct | budget_k) for each query
        budget_per_query: average token budget per query
        window: allocation granularity (default 16 tokens)

    Returns:
        allocations: (N,) array of per-query token budgets
    """
    n_queries, n_checkpoints = prob_vectors.shape
    total_budget = budget_per_query * n_queries

    # Initialize with the minimum allocation of one window per query.
    allocations = np.full(n_queries, window)
    remaining = total_budget - allocations.sum()

    # Current checkpoint index for each query.
    current_idx = np.zeros(n_queries, dtype=int)

    while remaining >= window:
        # Compute marginal gains from giving each query one more window.
        gains = np.full(n_queries, -np.inf)
        for i in range(n_queries):
            idx = current_idx[i]
            if idx + 1 < n_checkpoints:
                gains[i] = prob_vectors[i, idx + 1] - prob_vectors[i, idx]

        # Stop if no query would benefit from additional tokens.
        if gains.max() <= 0:
            break

        # Allocate one window to the query with the highest marginal gain.
        best_query = np.argmax(gains)
        allocations[best_query] += window
        current_idx[best_query] += 1
        remaining -= window

    return allocations


def difficulty_allocate(difficulties, proportions, budget_per_query):
    """
    Allocate tokens based on difficulty classification.

    Args:
        difficulties: (N,) array of difficulty labels (0=easy, 1=medium, 2=hard)
        proportions: (3,) array of [p_easy, p_medium, p_hard]
        budget_per_query: average token budget

    Returns:
        allocations: (N,) array of per-query token budgets
    """
    # Pre-computed optimal budgets per difficulty, derived from the
    # training-set curves (a possible implementation is sketched below).
    optimal_budgets = get_optimal_budgets(proportions, budget_per_query)

    # Per-query budget is a simple lookup by predicted category.
    allocations = np.zeros(len(difficulties))
    for i, diff in enumerate(difficulties):
        allocations[i] = optimal_budgets[diff]
    return allocations
```
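The helper `get_optimal_budgets` is not spelled out above. One plausible sketch, assuming it runs the same greedy logic over the three category-average early-stopping curves with each category's marginal gain weighted by its share of the batch, is shown below; the default toy curves are placeholders for the real training-set averages, and none of this is the exact implementation.

```python
import numpy as np

# Toy monotone curves for illustration only; in practice, use the (3, 16)
# category-average early-stopping curves computed from the training set.
DEFAULT_CATEGORY_CURVES = np.array([
    1.0 - 0.7 * np.exp(-np.arange(16) / 2.0),                   # easy
    np.linspace(0.05, 0.70, 16),                                # medium
    np.concatenate([np.zeros(10), np.linspace(0.0, 0.30, 6)]),  # hard
])

def get_optimal_budgets(proportions, budget_per_query,
                        category_curves=None, window=16):
    """
    Sketch: pick per-category budgets that maximize proportion-weighted expected
    accuracy while keeping the weighted average budget at budget_per_query.
    """
    if category_curves is None:
        category_curves = DEFAULT_CATEGORY_CURVES
    proportions = np.asarray(proportions, dtype=float)
    n_cats, n_checkpoints = category_curves.shape
    budgets = np.full(n_cats, window, dtype=float)
    idx = np.zeros(n_cats, dtype=int)
    spent = float(proportions @ budgets)  # average tokens per query so far

    while True:
        gains = np.full(n_cats, -np.inf)
        for c in range(n_cats):
            fits = spent + proportions[c] * window <= budget_per_query
            if idx[c] + 1 < n_checkpoints and fits:
                # Expected-accuracy gain per average query from one more window.
                gains[c] = proportions[c] * (
                    category_curves[c, idx[c] + 1] - category_curves[c, idx[c]]
                )
        if gains.max() <= 0:
            break
        best = int(np.argmax(gains))
        budgets[best] += window
        idx[best] += 1
        spent += proportions[best] * window

    return budgets
```

With these per-category budgets cached offline, serving-time allocation in `difficulty_allocate` reduces to a lookup per query, which is part of why the categorical scheme is both cheap and robust.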
Our work has several limitations that point toward future research directions. First, we evaluated exclusively on GSM8K, a grade-school math benchmark. While our methods should generalize to other reasoning domains—code generation, logical deduction, multi-hop QA—this remains to be validated empirically. Different domains may exhibit different difficulty signatures that require domain-specific predictors.
Second, our best predictor achieves r = 0.742, leaving substantial room for improvement before reaching oracle performance. Hybrid approaches that combine hidden-state features with linguistic pattern recognition may close this gap. Alternatively, ensembling multiple layer-specific predictors could capture complementary signals.
Third, we focused on single-trace generation, but many production systems use multi-trace aggregation methods like self-consistency or tree-of-thoughts. Extending predictive scheduling to these settings—predicting not just trace length but optimal number of traces—is a natural next step.
Finally, our current predictors provide point estimates without uncertainty quantification. At higher token budgets, where prediction errors become more consequential, calibrated uncertainty estimates could help the allocator hedge against misallocation. Queries with high predictive uncertainty might receive moderate allocations rather than extreme ones.
```bibtex
@article{brown2025predictive,
  title   = {Predictive Scheduling for Efficient Inference-Time Reasoning in Large Language Models},
  author  = {Brown, Katrina and Muppidi, Aneesh and Shahout, Rana},
  journal = {arXiv preprint arXiv:XXXX.XXXXX},
  year    = {2025}
}
```