ES-FoMo III Workshop, ICML 2025, Vancouver, Canada

Predictive Scheduling for Efficient Inference-Time Reasoning in Large Language Models

Katrina Brown*1, Aneesh Muppidi*1, Rana Shahout2
1Harvard College    2Harvard SEAS    *Equal contribution
Visual Abstract: Greedy token allocation based on predicted early-stopping probabilities
Given a fixed total token budget for a batch of queries, how do we distribute tokens across queries to maximize accuracy?
TL;DR

We train lightweight predictors to estimate how much reasoning each query needs before generation begins, then use a greedy algorithm to allocate a fixed token budget where it matters most. The result: up to +7.9 percentage points of accuracy at identical compute cost, closing over half the gap to an oracle with perfect foresight.

One Size Doesn't Fit All!

Chain-of-thought (CoT) reasoning has become the dominant paradigm for getting LLMs to solve complex problems. The idea is that instead of jumping straight to an answer, the model "thinks out loud," generating intermediate reasoning steps that guide it toward the correct solution. This approach has unlocked remarkable capabilities in logic, math, coding, and multi-step problem solving.

However, in practice, we typically allocate a fixed token budget to every query—say, 256 tokens of reasoning per problem. This uniform approach creates two distinct failure modes:

Easy queries waste compute. Consider a simple arithmetic problem like "What is 15% of 80?" A competent model can solve this in 30-40 tokens. Giving it 256 tokens means we're burning through unnecessary computation, adding latency, and inflating API costs—all for no accuracy benefit.

Hard queries get starved. Meanwhile, a multi-step word problem involving rates, ratios, and unit conversions might genuinely need 200+ tokens to work through correctly. If we've set our budget at 128 tokens to save costs, we're forcing the model to truncate its reasoning mid-thought, almost guaranteeing a wrong answer.

The waste is staggering at scale. Production LLM services handle millions of queries daily. If even 30% of those queries could be solved with half the typical token budget, the savings in compute, latency, and cost would be enormous. Conversely, if we're systematically under-allocating to hard problems, we're leaving accuracy on the table.

This raises a natural question: what if we could predict, before generating a single token, how much reasoning each query actually needs?

So... Predict First, Then Allocate

The core insight behind predictive scheduling is that the information needed to estimate a query's difficulty is already present in the model—we just need to extract it. When an LLM processes a question, its hidden states encode rich information about the problem's structure and likely solution path. We can train lightweight predictors to read these signals and forecast how much reasoning will be required.

Our framework operates in three stages. First, we run a fast forward pass through the model to extract hidden states from the query. Second, a trained predictor maps these hidden states to an estimate of the query's "reasoning requirements." Third, a scheduling algorithm uses these estimates to distribute a fixed total token budget across all queries in a batch, giving more tokens to queries that need them and fewer to queries that don't.
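
As a concrete sketch of the first stage, the snippet below extracts a pooled hidden state from one intermediate layer with a single forward pass over the prompt (no generation). The model name matches our experiments, but the mean-pooling choice and the query_features helper are illustrative assumptions rather than our exact implementation.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
model.eval()

@torch.no_grad()
def query_features(question: str, layer: int = 16) -> torch.Tensor:
    """Run one forward pass over the prompt and return a pooled hidden state
    from an intermediate layer (hidden_states[0] is the embedding layer)."""
    inputs = tokenizer(question, return_tensors="pt")
    outputs = model(**inputs, output_hidden_states=True)
    hidden = outputs.hidden_states[layer]      # (1, seq_len, d_model)
    return hidden.mean(dim=1).squeeze(0)       # (d_model,) feature vector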

Early-Stopping Probability Vectors

To train our predictors, we first need ground truth labels that capture how a query's success probability changes with token budget. We construct these labels empirically using what we call "early-stopping probability vectors."

For each query in our training set, we generate 100 independent reasoning traces using temperature sampling. Within each trace, we insert a forced-answer probe at regular intervals (every 16 tokens, up to 256 tokens). The probe is a string that forces the model to commit to a final answer at that point:

"Oh, I suddenly got the answer to the whole problem, Final Answer\n\n\[ \boxed{"

By comparing these forced answers against the ground truth, we can compute the empirical probability of success at each checkpoint. The result is a 16-dimensional vector for each query, where entry k represents P(correct | k×16 tokens of reasoning).
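
As a minimal sketch, assuming we already have a boolean outcome matrix per query (100 traces × 16 checkpoints, with variable names of our choosing), the vector is just a column-wise mean:

import numpy as np

def early_stopping_vector(probe_correct: np.ndarray) -> np.ndarray:
    """probe_correct: (n_traces, n_checkpoints) booleans, where entry (t, k) says
    whether trace t's forced answer after (k+1)*16 tokens matched the ground truth.
    Returns the 16-dimensional empirical success-probability vector."""
    return probe_correct.mean(axis=0)

# Toy example: 100 traces, 16 checkpoints of 16 tokens each
rng = np.random.default_rng(0)
toy_outcomes = rng.random((100, 16)) < np.linspace(0.1, 0.9, 16)
p_vec = early_stopping_vector(toy_outcomes)   # entry k ≈ P(correct | (k+1)*16 tokens)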

These vectors reveal striking patterns. Easy problems show a rapid rise in success probability within the first few checkpoints, then plateau near 1.0. Hard problems remain near zero until much higher token budgets, and may never reach high accuracy even at maximum budget. Medium problems fall somewhere in between, with gradual improvement across the token range.

Figure 1. Distribution of correct-answer probabilities across token budgets. At low budgets, success probabilities cluster near 0; at high budgets, they cluster near 1. The bimodal structure suggests that problems tend to be either "solved" or "not solved" at each budget level, with relatively few queries in intermediate states.

Approach 1: MLP on Hidden States

Our first prediction approach extracts hidden states from intermediate transformer layers and trains a simple MLP to predict the early-stopping probability vector. The architecture is deliberately lightweight: two fully-connected layers with 256 hidden units, ReLU activations, and a final sigmoid to constrain outputs to [0,1]. We train separate MLPs for each transformer layer to systematically evaluate which layers encode the most predictive signal.
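
A sketch of this predictor in PyTorch; the input dimensionality (1536 for DeepSeek-R1-Distill-Qwen-1.5B, best read from model.config.hidden_size) and the use of MSE as the training loss are assumptions on our part:

import torch
import torch.nn as nn

class EarlyStopMLP(nn.Module):
    """Two fully-connected layers with 256 hidden units, ReLU, and a sigmoid
    over the 16 early-stopping checkpoints."""
    def __init__(self, d_model: int = 1536, n_checkpoints: int = 16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, 256),
            nn.ReLU(),
            nn.Linear(256, n_checkpoints),
            nn.Sigmoid(),
        )

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        return self.net(hidden_state)   # (batch, 16), each entry in [0, 1]

# Training sketch: regress onto the empirical probability vectors
predictor = EarlyStopMLP()
optimizer = torch.optim.Adam(predictor.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()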

Why hidden states? When a transformer processes a query, each layer builds increasingly abstract representations of the input. Early layers capture surface-level features like syntax and token identity. Middle layers integrate these into semantic representations that capture meaning and relationships. Late layers specialize toward the specific output format. Our hypothesis was that middle layers—which balance syntactic and semantic information—would provide the strongest signal for reasoning difficulty prediction.

Approach 2: LoRA Fine-Tuned Predictor

Our second approach takes a more integrated route: we fine-tune the base LLM itself (with LoRA) to predict early-stopping probabilities directly from the raw question text. This method uses rank-16 LoRA adapters on the query and value projections, followed by a regression head that maps the final hidden state to a 16-dimensional probability vector.
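
A sketch of this predictor built with the peft library; the final-token pooling, the LoRA alpha/dropout values, and the class name are our assumptions:

import torch
import torch.nn as nn
from transformers import AutoModel
from peft import LoraConfig, get_peft_model

class EarlyStopRegressor(nn.Module):
    """Base LM with rank-16 LoRA adapters on the query/value projections and a
    regression head mapping the final hidden state to 16 probabilities."""
    def __init__(self, model_name: str, n_checkpoints: int = 16):
        super().__init__()
        base = AutoModel.from_pretrained(model_name)
        lora_cfg = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05,
                              target_modules=["q_proj", "v_proj"])
        self.backbone = get_peft_model(base, lora_cfg)
        self.head = nn.Linear(base.config.hidden_size, n_checkpoints)

    def forward(self, input_ids, attention_mask):
        out = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        last = out.last_hidden_state[:, -1, :]   # hidden state of the final token
        return torch.sigmoid(self.head(last))    # (batch, 16) in [0, 1]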

The motivation here is that predicting reasoning complexity is fundamentally a linguistic task. Certain phrases and structures—"prove that," "derive," "show that"—tend to signal harder problems, while "calculate," "find," or "what is" often indicate more straightforward computations. A fine-tuned language model should be able to learn these patterns directly from text, without requiring explicit hidden state extraction.

Which Transformer Layers Carry the Signal?

One of our most striking findings concerns where in the transformer the predictive signal lives. We trained identical MLP predictors on hidden states from each of the 28 layers of DeepSeek-R1-Distill-Qwen-1.5B and measured their correlation with ground truth early-stopping probabilities.

The results show a clear inverted U-shape across model depth. Early layers (1-6) achieve correlations below 0.6—they capture syntactic features but lack the semantic depth needed for difficulty assessment. Late layers (21-28) show similar limitations, having specialized too heavily toward output generation to retain general reasoning signals. But middle layers (12-17) significantly outperform both extremes, with layer 16 achieving the highest test correlation of 0.742.

Figure 2. Layer-wise correlation analysis for early-stopping prediction. The top panel shows Pearson correlation coefficients for MLPs trained on different layer features. Middle layers (12-17) substantially outperform early and late layers, confirming our hypothesis about the optimal balance of syntactic and semantic features.

Key Finding: Middle Layers Are Most Informative

Layer 16 of DeepSeek-R1-Distill-Qwen-1.5B achieves Pearson r = 0.742 for predicting early-stopping probabilities. Middle layers (12-17) consistently outperform both early layers (which lack semantic depth) and late layers (which over-specialize for output generation). This finding has practical implications: practitioners can extract maximum predictive signal with minimal computational overhead by targeting middle layers specifically.

We also analyzed the efficiency of each layer using a correlation-to-loss ratio metric, which measures predictive power per unit of error. This analysis further confirms the effectiveness of middle layers, which achieve 15-20% higher efficiency than early or late layers. The implication is clear: if you're building a reasoning-length predictor, target the middle of the network.

Figure 3. Loss analysis across transformer layers. The top-left panel shows test MSE by layer; the top-right shows the MSE-correlation tradeoff; the bottom panel presents correlation-to-loss ratios for top-performing layers. Middle layers achieve the best balance of low loss and high correlation.

The Greedy Allocation Algorithm

Once we have predicted early-stopping probabilities for each query in a batch, we need an algorithm to translate these predictions into token allocations. The goal is to maximize expected accuracy across the batch subject to a fixed total token budget.

We adopt a greedy approach that iteratively allocates tokens to the query with the highest marginal gain. The algorithm begins by assigning each query a minimum allocation (16 tokens). Then, at each step, it computes the expected accuracy improvement from giving one additional token window to each query and allocates to the query with the highest gain. This process continues until the budget is exhausted or no query would benefit from additional tokens.

Algorithm 1: Greedy Token Allocation
Require: Q (queries), B (per-query token budget), W (window = 16), P (predicted probability vectors)
1: allocations ← [W] × |Q|    // minimum allocation
2: remaining ← B·|Q| - Σ allocations
3: while remaining ≥ W do
4:     gains ← ComputeGains(P, allocations, W)
5:     if max(gains) ≤ 0 then break
6:     i* ← argmax(gains)
7:     allocations[i*] += W
8:     remaining -= W
9: return allocations

The ComputeGains function returns the marginal expected-accuracy improvement from adding one more window of W tokens to each query. For query i currently at checkpoint k, the gain is simply P[i, k+1] - P[i, k]: the predicted increase in success probability from extending the reasoning budget.

This greedy algorithm has an intuitive interpretation. At low total budgets, it prioritizes easy queries because they offer the highest marginal gains—a small token investment yields a large jump in success probability. As the budget increases and easy queries saturate, the algorithm shifts to medium queries, then finally to hard queries that require substantial investment before showing any improvement.

Figure 4. Visualization of greedy allocation in action. The left panel shows the early-stopping probability heatmap for eight queries (darker = higher P(correct)). The right panel shows how tokens are progressively allocated across queries. Blue outlines indicate which query receives additional tokens at each step. Note how the algorithm initially favors easy queries (top rows) before shifting budget to harder queries (bottom rows).

A Simpler Alternative: Difficulty-Based Allocation

Predicting full 16-dimensional probability vectors may be overkill for practical allocation. A simpler approach is to classify each query into one of three difficulty categories—easy, medium, or hard—and allocate tokens based on category membership. This reduces the prediction problem from regression to classification, which is often more robust to noise.

We define difficulty categories using the 256-token success probability from our training data. Queries in the bottom 20 percent (success probability ≤ 0.18) are labeled "hard"; those in the top 20 percent (≥ 0.84) are labeled "easy"; the remainder are "medium." This stratification yields comparably sized easy and hard classes and captures the qualitative distinction between problems that are trivially solvable, genuinely challenging, or somewhere in between.

Split   Total   Easy    Medium   Hard
Train   7,450   1,506   4,437    1,507
Test    1,294     271     760      263

Difficulty thresholds derived from training set: p₂₀ = 0.18, p₈₀ = 0.84 based on 256-token accuracy.
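
A sketch of the stratification step, assuming p_final holds each training query's empirical 256-token success probability (the 0/1/2 label encoding matches the allocation code in the Code section):

import numpy as np

def label_difficulty(p_final: np.ndarray, lo_pct: float = 20, hi_pct: float = 80):
    """Label queries easy (0), medium (1), or hard (2) using percentiles of the
    256-token success probability; on our training set the cutoffs are ~0.18 and ~0.84."""
    lo, hi = np.percentile(p_final, [lo_pct, hi_pct])
    labels = np.ones(len(p_final), dtype=int)   # medium by default
    labels[p_final <= lo] = 2                   # hard
    labels[p_final >= hi] = 0                   # easy
    return labels, (lo, hi)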

To validate this categorization, we computed average early-stopping curves for each difficulty class. The results confirm that our categories capture fundamentally different reasoning patterns: easy problems show rapid improvement and early plateau; medium problems improve gradually across the token range; hard problems remain near zero until the highest budgets. Crucially, these patterns are consistent between train and test sets, indicating that the categorization generalizes.

Figure 5. Average early-stopping probability curves by difficulty class. Blue lines show training set averages; red lines show test set averages. The close tracking between train and test within each category validates our difficulty stratification approach.

Difficulty Classifiers: Few-Shot vs. Fine-Tuned

We evaluated two approaches for difficulty classification. The first uses few-shot prompting with a commercial LLM (o4-mini), providing three labeled examples and asking the model to classify new queries. The second fine-tunes our base model with LoRA to predict difficulty directly from question text.
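
The fine-tuned classifier can reuse the LoRA setup from the regressor sketch above, swapping the 16-way sigmoid head for a 3-way logit head trained with cross-entropy; as before, the pooling choice and hyperparameters are assumptions:

import torch.nn as nn
from transformers import AutoModel
from peft import LoraConfig, get_peft_model

class DifficultyClassifier(nn.Module):
    """LoRA-adapted backbone with a 3-way head over easy / medium / hard."""
    def __init__(self, model_name: str, n_classes: int = 3):
        super().__init__()
        base = AutoModel.from_pretrained(model_name)
        cfg = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"])
        self.backbone = get_peft_model(base, cfg)
        self.head = nn.Linear(base.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        out = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        return self.head(out.last_hidden_state[:, -1, :])   # logits for CrossEntropyLoss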

The results strongly favor fine-tuning. The LoRA classifier achieves 66.3% test accuracy, compared to just 41.6% for the few-shot approach—a 24.7 percentage point improvement. Both methods struggle most with the medium category, which makes sense: the boundary between "medium" and "easy" (or "hard") is inherently fuzzy, and many problems could reasonably be classified either way.

Figure 6. Confusion matrices for difficulty classification. The few-shot approach (left) achieves only 41.6% accuracy, with substantial confusion between all categories. The LoRA fine-tuned classifier (right) reaches 66.3% accuracy, with much stronger diagonal elements. Both methods show the expected pattern of medium-category confusion.

Insight: Coarse-Grained Classification Outperforms Fine-Grained Regression

This result offers an important methodological lesson. Our LoRA-based early-stopping regressor achieved only r = 0.444 when predicting continuous probability vectors from text. But the same architecture achieves 66.3% accuracy when predicting discrete difficulty categories. Linguistic features in problem statements do correlate with reasoning complexity—but this relationship is more robustly captured through coarse categorization than fine-grained continuous prediction.

Given difficulty predictions, the allocation algorithm is straightforward: we precompute optimal per-category budgets using the training set curves, then assign each query the budget corresponding to its predicted category. This approach is simpler than full greedy allocation and, surprisingly, often performs better in practice.

Results: Putting It All Together

We evaluated our predictive scheduling approaches on the GSM8K test set, comparing against a uniform-allocation baseline and an oracle that uses ground-truth early-stopping probabilities. The results demonstrate that adaptive allocation consistently outperforms uniform budgeting, with difficulty-based allocation proving most effective.

Figure 7. Accuracy versus token budget for different allocation strategies. Difficulty-based allocation using LoRA predictions (green) consistently outperforms both size-based allocation using MLP layer 16 predictions (blue) and the uniform baseline (gray) across all budget levels. The gap is most pronounced at intermediate budgets where allocation decisions have the greatest impact.

At an average budget of 96 tokens per query, difficulty-based allocation achieves up to 7.9 percentage points higher accuracy than uniform allocation—at identical total compute cost. This improvement closes over 50% of the gap to the oracle upper bound, demonstrating that our predictors capture meaningful signal about reasoning requirements.

Interestingly, the size-based approach (using MLP predictions of continuous probability vectors) shows mixed results. It outperforms uniform allocation at low budgets (16-96 tokens) where constraints are tight and adaptive allocation matters most. But at higher budgets, prediction errors become more consequential, and size-based allocation actually underperforms the baseline. Difficulty-based allocation, with its more robust categorical predictions, avoids this pitfall and maintains consistent gains across all budget levels.

Figure 8. Size-based allocation using MLP predictions from layer 16. The adaptive approach outperforms uniform allocation in constrained-budget regimes (16-96 tokens) but falls below baseline at higher budgets. This crossover suggests that continuous prediction errors accumulate as constraints relax.
Figure 9. Optimal token allocation by difficulty class under the oracle greedy algorithm. At low total budgets, nearly all tokens go to easy queries (green) which offer the highest marginal gains. As budget increases, allocation shifts to medium queries (yellow), then finally to hard queries (red) which require substantial investment before showing improvement.

Code

The core greedy allocation algorithm is simple to implement:

import numpy as np

def greedy_allocate(prob_vectors, budget_per_query, window=16):
    """
    Allocate tokens greedily based on predicted early-stopping probabilities.

    Args:
        prob_vectors: (N, K) array of P(correct | budget_k) for each query
        budget_per_query: average token budget per query
        window: allocation granularity (default 16 tokens)

    Returns:
        allocations: (N,) array of per-query token budgets
    """
    n_queries, n_checkpoints = prob_vectors.shape
    total_budget = budget_per_query * n_queries

    # Initialize with minimum allocation
    allocations = np.full(n_queries, window)
    remaining = total_budget - allocations.sum()

    # Current checkpoint index for each query
    current_idx = np.zeros(n_queries, dtype=int)

    while remaining >= window:
        # Compute marginal gains
        gains = np.full(n_queries, -np.inf)
        for i in range(n_queries):
            idx = current_idx[i]
            if idx + 1 < n_checkpoints:
                gains[i] = prob_vectors[i, idx + 1] - prob_vectors[i, idx]

        # Check if any positive gain exists
        if gains.max() <= 0:
            break

        # Allocate to query with highest marginal gain
        best_query = np.argmax(gains)
        allocations[best_query] += window
        current_idx[best_query] += 1
        remaining -= window

    return allocations


def difficulty_allocate(difficulties, proportions, budget_per_query):
    """
    Allocate tokens based on difficulty classification.

    Args:
        difficulties: (N,) array of difficulty labels (0=easy, 1=medium, 2=hard)
        proportions: (3,) array of [p_easy, p_medium, p_hard]
        budget_per_query: average token budget

    Returns:
        allocations: (N,) array of per-query token budgets
    """
    # Pre-computed optimal budgets per difficulty (from training set curves)
    optimal_budgets = get_optimal_budgets(proportions, budget_per_query)

    allocations = np.zeros(len(difficulties))
    for i, diff in enumerate(difficulties):
        allocations[i] = optimal_budgets[diff]

    return allocations
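
The helper get_optimal_budgets is not defined above. A minimal sketch, assuming per-category average early-stopping curves estimated from the training set are available (the placeholder curves below are illustrative, not our measured values):

# Illustrative per-category average early-stopping curves, shape (3, 16):
# rows are easy / medium / hard; columns are checkpoints of 16 tokens each.
# Replace with curves estimated from your own training data.
CATEGORY_CURVES = np.stack([
    np.clip(np.linspace(0.3, 1.0, 16), 0, 1),   # easy: fast rise, early plateau
    np.clip(np.linspace(0.0, 0.7, 16), 0, 1),   # medium: gradual improvement
    np.clip(np.linspace(-0.3, 0.4, 16), 0, 1),  # hard: flat until high budgets
])

def get_optimal_budgets(proportions, budget_per_query, window=16):
    """Brute-force search over per-category checkpoint choices, maximizing
    expected accuracy subject to the average-budget constraint."""
    n_checkpoints = CATEGORY_CURVES.shape[1]
    best_budgets, best_acc = np.full(3, window), -1.0
    for e in range(n_checkpoints):
        for m in range(n_checkpoints):
            for h in range(n_checkpoints):
                budgets = window * (np.array([e, m, h]) + 1)
                if np.dot(proportions, budgets) > budget_per_query:
                    continue   # violates the average-budget constraint
                acc = np.dot(proportions, CATEGORY_CURVES[[0, 1, 2], [e, m, h]])
                if acc > best_acc:
                    best_acc, best_budgets = acc, budgets
    return best_budgets


# Example usage with toy inputs
rng = np.random.default_rng(0)
toy_probs = np.sort(rng.random((8, 16)), axis=1)          # monotone probability vectors
print(greedy_allocate(toy_probs, budget_per_query=96))
print(difficulty_allocate(np.array([0, 1, 1, 2]),          # easy, medium, medium, hard
                          proportions=np.array([0.2, 0.6, 0.2]),
                          budget_per_query=96))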

Key Takeaways

Finding 1
Significant accuracy gains at equal compute. Predictive scheduling achieves up to +7.9 percentage points over uniform allocation, closing more than half the gap to an oracle with perfect foresight. These gains come purely from smarter allocation—no additional tokens are consumed.
Finding 2
Middle transformer layers are most informative. Layer 16 achieves r = 0.742 for reasoning-length prediction. This finding provides practical guidance: extract hidden states from layers 12-17 for maximum predictive signal with minimal overhead.
Finding 3
Coarse-grained prediction beats fine-grained. Difficulty-based allocation consistently outperforms continuous size prediction. Discrete categories are more robust to noise and easier to act on, even though they discard information.
Finding 4
Plug-and-play compatibility. Our framework requires no modifications to the underlying LLM. Any model that exposes hidden states can benefit from predictive scheduling, making it immediately applicable to existing deployments.

Limitations and Future Directions

Our work has several limitations that point toward future research directions. First, we evaluated exclusively on GSM8K, a grade-school math benchmark. While our methods should generalize to other reasoning domains—code generation, logical deduction, multi-hop QA—this remains to be validated empirically. Different domains may exhibit different difficulty signatures that require domain-specific predictors.

Second, our best predictor achieves r = 0.742, leaving substantial room for improvement before reaching oracle performance. Hybrid approaches that combine hidden-state features with linguistic pattern recognition may close this gap. Alternatively, ensembling multiple layer-specific predictors could capture complementary signals.

Third, we focused on single-trace generation, but many production systems use multi-trace aggregation methods like self-consistency or tree-of-thoughts. Extending predictive scheduling to these settings—predicting not just trace length but optimal number of traces—is a natural next step.

Finally, our current predictors provide point estimates without uncertainty quantification. At higher token budgets, where prediction errors become more consequential, calibrated uncertainty estimates could help the allocator hedge against misallocation. Queries with high predictive uncertainty might receive moderate allocations rather than extreme ones.

References

  1. Wei et al. "Chain-of-Thought Prompting Elicits Reasoning in Large Language Models" arXiv 2022
  2. Wang et al. "Self-Consistency Improves Chain of Thought Reasoning in Language Models" ICLR 2023
  3. Yao et al. "Tree of Thoughts: Deliberate Problem Solving with Large Language Models" arXiv 2023
  4. Damani et al. "Learning How Hard to Think: Input-Adaptive Allocation of LM Computation" arXiv 2024
  5. Fu et al. "Efficiently Serving LLM Reasoning Programs with Certaindex" arXiv 2024
  6. Han et al. "Token-Budget-Aware LLM Reasoning" arXiv 2024
  7. Wu et al. "When More is Less: Understanding Chain-of-Thought Length in LLMs" arXiv 2025
  8. Li et al. "Escape Sky-high Cost: Early-stopping Self-Consistency for Multi-step Reasoning" arXiv 2024
  9. Hu et al. "LoRA: Low-Rank Adaptation of Large Language Models" ICLR 2022
  10. Cobbe et al. "Training Verifiers to Solve Math Word Problems" arXiv 2021
  11. Kwon et al. "Efficient Memory Management for Large Language Model Serving with PagedAttention" SOSP 2023

Citation

@article{brown2025predictive,
  title={Predictive Scheduling for Efficient Inference-Time
         Reasoning in Large Language Models},
  author={Brown, Katrina and Muppidi, Aneesh and Shahout, Rana},
  journal={arXiv preprint arXiv:XXXX.XXXXX},
  year={2025}
}