Search-Based RL · JAX · Combinatorial Optimization

Policy Iteration via Search Distillation

MCTS as a policy improvement operator. For single-agent planning with known dynamics, we use Gumbel MuZero search to compute improved policy targets—softmax(prior + Q)—then distill them into a neural network via cross-entropy. This is policy iteration, but the "improvement" step is search and the "evaluation" step is supervised learning. Pure JAX, competitive with PPO on wall-clock time.

Aneesh Muppidi · February 2026
Policy Iteration via Search Distillation for 3D Bin Packing
Left: The 3D bin packing problem. Fit as many items as possible into a fixed container to maximize volume utilization. Items are placed one at a time; opacity shows placement order (newer = more solid). Right: MCTS explores possible placements before each decision, with node size indicating visit count.
TL;DR

Policy iteration, but the improvement step is Gumbel MuZero search and the learning step is supervised distillation. With JAX's vmap and pmap, we can run search fast enough to compete with PPO on wall-clock time—while achieving 96% volume utilization vs PPO's 90% on 3D bin packing. Closely related to Expert Iteration, but with important differences we make explicit.


Policy Iteration via Search Distillation

When AlphaZero achieved superhuman performance at Go and Chess, the key innovation wasn't just MCTS or neural networks—it was how the two components reinforced each other. The network provides fast intuition to guide search, and search provides high-quality targets to improve the network. But there's an important detail that's often glossed over: AlphaZero uses self-play because Go and Chess are two-player games. The network plays against itself, generating training data from both sides of the board.

But many important problems don't have an opponent. Consider 3D bin packing: you're given a set of rectangular boxes and a container, and your goal is to pack as many boxes as possible to maximize volume utilization. It sounds like something you could solve by just being clever about placement order, but the combinatorial explosion makes this problem far harder than it appears.

In the form we study here (a single fixed container, maximize packed volume), 3D bin packing is a 3D knapsack, or container-loading, problem. It generalizes the classic Knapsack problem, one of Karp's 21 NP-complete problems, so it is NP-hard. (For a thorough treatment of computational complexity and bin packing specifically, Garey and Johnson's Computers and Intractability remains the canonical reference.) The difficulty is navigating the astronomical number of possible placements. With n items and m potential positions, the number of placement sequences grows roughly as \(O(m^n)\). Even for modest problem sizes, brute-force search is hopeless.
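To put a number on that growth, here is a back-of-envelope count. It assumes the Jumanji setup described later (40 observable EMS, 20 items per episode), treats every (space, remaining item) pair as a legal choice at every step, and ignores that the set of spaces changes after each placement, so it is a loose upper bound, not an exact count.

```python
import math


def num_placement_sequences(num_spaces: int, num_items: int) -> int:
    """Loose upper bound on the number of complete placement sequences.

    At step t we pick one of `num_spaces` spaces and one of the
    (num_items - t) items not yet placed; legality is ignored.
    """
    total = 1
    for t in range(num_items):
        total *= num_spaces * (num_items - t)
    return total


bound = num_placement_sequences(num_spaces=40, num_items=20)
print(f"{bound:.3e}")  # on the order of 10**50
```

The bound equals 40^20 · 20!, i.e. the m^n term times an extra n! for item ordering. Most of those pairs are illegal in a real instance, so the true count is far smaller, but it still grows exponentially in the number of items.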

For problems like this, self-play doesn't apply—there's no "other side" to play. Instead, we run search on each state and use the resulting action distribution as a training target. This is closely related to Expert Iteration (Anthony et al., 2017), which frames the idea as "MCTS is an expert teacher, the network is an apprentice."

But as we'll see, our implementation differs from vanilla Expert Iteration in important ways—most notably, we use Gumbel MuZero's policy improvement operator rather than visit-count targets. So we prefer a more precise framing: policy iteration via search distillation.

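The algorithm is best understood as policy iteration, the classic RL framework where you alternate between evaluating your current policy and improving it. But our implementation of each step is unusual: the improvement step is Gumbel MuZero search, which uses the current network to produce a stronger policy at every visited state, and the evaluation step is supervised learning, which distills that search policy into the network's policy head and regresses its value head onto Monte Carlo returns.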

The network gets better, which makes search stronger (since it uses the network for priors and value estimates), which produces even better targets, and so on. This is the core loop:

┌─────────────────────────────────────────────────────────────────────────────┐
│                                                                             │
│   ┌─────────┐      improved       ┌─────────┐      distill      ┌───────┐   │
│   │ Gumbel  │ ──────────────────> │ Policy  │ ─────────────────>│  NN   │   │
│   │ MuZero  │      targets        │ Targets │      into         │Policy │   │
│   └─────────┘                     └─────────┘                   └───────┘   │
│        ▲                                                            │       │
│        │                          next iteration                    │       │
│        └────────────────────────────────────────────────────────────┘       │
│                                                                             │
│   The network improves -> search becomes stronger -> better targets -> ...  │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
        

Relationship to Expert Iteration

The ExIt paper (Anthony et al., 2017) makes the lineage explicit: its Appendix A notes that AlphaGo Zero "independently developed the EXIT algorithm."

However, there are meaningful differences that make "Policy Iteration via Search Distillation" a more precise description of what we're actually doing:

How This Differs from Vanilla Expert Iteration

1. Policy targets are not visit counts. Traditional ExIt uses n(s,a) / n(s), the proportion of times MCTS visited each action. We use Gumbel MuZero, which computes softmax(prior_logits + σ(completed_Q)), where unvisited actions are "completed" with a value estimate and σ is a monotonically increasing scaling of the Q-values. This is a direct policy improvement operator with finite-sample guarantees, not an asymptotic property of visit counts.

2. The "expert" is not independent. In classical ExIt framing, the expert is often conceptualized as a separate, stronger system. Here, the expert is the current network + search—they're coupled. MCTS uses the network's priors to guide search and the network's value estimates to evaluate leaves. This is closer to "compute-amplified self-improvement" than "learning from an external expert."

3. Value targets come from Monte Carlo returns, not MCTS. The value head is trained via supervised regression to empirical returns from rollouts—not from MCTS root values or backed-up estimates. So only the policy is "search-improved"; the value is trained by standard Monte Carlo.

4. Single-agent, no self-play. AlphaZero uses self-play because Go and Chess are two-player games. Our target domain (bin packing) is single-agent—there's no opponent. We simply run search on the current state and distill the result.

None of these differences make the approach invalid. They just mean the ExIt framing is slightly imprecise. "Policy Iteration via Search Distillation" captures the mechanics more accurately: search is the improvement operator, supervised learning is how we internalize the improvement.

Why This Is Not PPO (or Any Policy Gradient Method)

It's worth being explicit about what kind of algorithm this is and isn't. PPO and other policy gradient methods train the policy directly from reward:

\[ \nabla_\theta J(\theta) = \mathbb{E}\left[ \nabla_\theta \log \pi_\theta(a|s) \cdot \hat{A}(s,a) \right] \]

The gradient signal comes from sampled actions and sampled advantages. This is high-variance, which is why PPO needs clipping, GAE, entropy bonuses, and careful hyperparameter tuning.

Our approach is fundamentally different. The policy is trained by imitation:

\[ \mathcal{L}_{\text{policy}}(\theta) = -\sum_{a} \pi_{\text{search}}(a|s) \log \pi_\theta(a|s) \]

This is cross-entropy between the network's output and the search-produced distribution. Reward doesn't appear in this loss at all—it only influences the policy indirectly through how search evaluates actions. The supervision signal is dense (a full distribution over actions) rather than sparse (a single sampled action and its return).
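A minimal JAX sketch of this loss, assuming logits and search targets both have shape (batch, num_actions). The gradient check illustrates the "dense" point: every action's logit receives a gradient of softmax(logits) minus the search target (per example), not just the one action that was sampled.

```python
import jax
import jax.numpy as jnp


def policy_distillation_loss(logits, search_policy):
    """Cross-entropy between search targets and the network policy.

    logits:        [B, A] network logits (invalid actions masked with a large
                   finite negative value, not -inf, so 0 * log p stays finite)
    search_policy: [B, A] action_weights produced by search
    """
    # Targets are fixed labels: no gradient flows back into the search output.
    search_policy = jax.lax.stop_gradient(search_policy)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.mean(jnp.sum(search_policy * log_probs, axis=-1))


logits = jnp.zeros((1, 4))
target = jnp.array([[0.7, 0.1, 0.1, 0.1]])
loss = policy_distillation_loss(logits, target)
# For soft-label cross-entropy, d loss / d logits = softmax(logits) - target.
grads = jax.grad(policy_distillation_loss)(logits, target)
```

With uniform logits the loss is log 4 regardless of the target, and the gradient pushes the first logit up and the other three down in a single step.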

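So where does reward come in? Two places: in the search itself, where the rewards returned by the recurrent function are backed up through the tree to form the Q-values that shape the search policy; and in the value head, which regresses onto Monte Carlo returns and is then used by search to evaluate leaves.

But the policy's own gradient signal is pure imitation of search. In principle, this should make learning more stable than with PPO, since we're doing supervised learning on high-quality targets rather than following high-variance policy gradients.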

The Algorithm, Precisely

To be completely explicit about what we're doing:

Algorithm: Policy Iteration via Search Distillation
1: Initialize policy-value network πθ, Vθ
2: for iteration = 1, 2, ... do
3:   // Collect experience with search-guided rollouts
4:   for each episode do
5:     for each step do
6:       Run Gumbel MuZero search from current state
7:       Record (state, action_weights, reward)
8:       Sample action from action_weights, step environment
9:   // Compute targets
10:   policy_targets <- action_weights from search
11:   value_targets <- Monte Carlo returns from rewards
12:   // Update network (supervised learning)
13:   θ <- θ - α∇θ[CrossEntropy(πθ, policy_targets) + MSE(Vθ, value_targets)]
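Step 11 is easy to get subtly wrong, so here is one way to compute the Monte Carlo value targets: a reverse jax.lax.scan over a time-major rollout. This sketch assumes rewards and Jumanji-style discounts (0 on the terminating step, 1 otherwise) have been stacked to shape (T, ...), and that any padding steps after termination carry zero reward; the original code's bookkeeping may differ.

```python
import jax
import jax.numpy as jnp


def monte_carlo_returns(rewards, discounts):
    """Reward-to-go G_t = r_t + discount_t * G_{t+1}, computed backwards in time.

    rewards, discounts: [T, ...] time-major arrays from a rollout.
    A discount of 0 on the terminal step stops returns from leaking
    across episode boundaries.
    """
    def backward_step(next_return, inputs):
        reward, discount = inputs
        ret = reward + discount * next_return
        return ret, ret

    # reverse=True iterates t = T-1, ..., 0 but stacks outputs in original order.
    _, returns = jax.lax.scan(
        backward_step, jnp.zeros_like(rewards[0]), (rewards, discounts), reverse=True
    )
    return returns


rewards = jnp.array([0.1, 0.2, 0.3])
discounts = jnp.array([1.0, 1.0, 0.0])  # episode terminates on the last step
returns = monte_carlo_returns(rewards, discounts)
```

Because BinPack rewards are utilization increments, G_t here is the utilization still to be gained from step t onward, which is what the value head regresses onto.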

The key idea is that search is the policy improvement operator. Given the current policy π, running Gumbel MuZero search produces an improved policy π' (the guarantee holds when the Q-values of the searched actions are estimated accurately). Training the network to match π' internalizes that improvement. Repeat.

Comparison to PPO

PPO is an on-policy policy gradient method that maximizes a clipped surrogate objective:

\[ L^{\text{CLIP}}(\theta)=\mathbb{E}_t\left[\min\left(r_t(\theta)\,\hat{A}_t,\; \text{clip}(r_t(\theta),1-\epsilon,1+\epsilon)\,\hat{A}_t\right)\right] \]

where \(r_t(\theta) = \pi_\theta(a_t|s_t) / \pi_{\theta_{\text{old}}}(a_t|s_t)\) is the probability ratio. PPO works well when you can collect lots of trajectories and the environment is too complex for planning. But PPO's gradient signal is based on sampled actions and sampled returns, so variance control (GAE, baselines, clipping) becomes essential.
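For contrast, the clipped surrogate is only a few lines of JAX. This sketch assumes per-sample log-probabilities under the new and old policies and precomputed advantages; ε = 0.2 is the common default from the PPO paper, not necessarily the value used in these experiments.

```python
import jax.numpy as jnp


def ppo_clip_objective(logp_new, logp_old, advantages, epsilon=0.2):
    """Clipped surrogate L^CLIP (to be maximized; negate it for a loss)."""
    ratio = jnp.exp(logp_new - logp_old)  # r_t(theta)
    unclipped = ratio * advantages
    clipped = jnp.clip(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantages
    # The elementwise minimum makes the objective a pessimistic bound.
    return jnp.mean(jnp.minimum(unclipped, clipped))
```

Each sample contributes one log-probability ratio times one scalar advantage estimate, which is where the variance comes from.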

Search distillation largely sidesteps this variance problem on the policy side. Instead of learning from a single sampled action and its noisy return, we learn from the full search distribution, a much richer signal. And because we have exact environment dynamics, search can do genuine lookahead planning, discovering good action sequences that pure policy gradients might take much longer to find.

| Aspect                  | PPO                        | Search Distillation                  |
|-------------------------|----------------------------|--------------------------------------|
| Policy learning signal  | Sampled action + advantage | Full distribution from search        |
| Reward's role in policy | Direct (policy gradient)   | Indirect (shapes search preferences) |
| Variance                | High (needs GAE, clipping) | Low (supervised learning)            |
| Lookahead               | None (reactive)            | Multiple simulations per decision    |
| Requires dynamics       | No                         | Yes (for search)                     |

Monte Carlo Tree Search with mctx

We use mctx, DeepMind's JAX-native MCTS library. The library is worth studying because it implements batched MCTS that compiles entirely under XLA—the entire search runs as a single compiled program, not Python loops calling into JAX. This is what makes it fast enough to compete with model-free methods on wall-clock time.

The mctx library requires you to provide a "recurrent function" that simulates environment transitions. For problems with known dynamics, this is trivial—we just call the environment's step function:

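import jax
import mctx

# env (Jumanji BinPack), forward (the Haiku network), unflatten_action,
# get_valid_action_mask, and apply_action_mask are defined elsewhere.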

def recurrent_fn(model, rng_key, action, state):
    """Environment model for MCTS simulation."""
    model_params, model_state = model
    
    # Step the environment (we have perfect dynamics!)
    action_pair = unflatten_action(action)
    next_state, timestep = jax.vmap(env.step)(state, action_pair)
    
    # Get network predictions for the next state
    observation = timestep.observation
    (logits, value), _ = forward.apply(
        model_params, model_state, observation, is_eval=True
    )
    
    # Mask invalid actions
    valid_mask = get_valid_action_mask(observation.action_mask)
    logits = apply_action_mask(logits, valid_mask)
    
    return mctx.RecurrentFnOutput(
        reward=timestep.reward,
        discount=timestep.discount,
        prior_logits=logits,
        value=value,
    ), next_state

With this function defined, running MCTS is straightforward. We use gumbel_muzero_policy, which provides a more principled approach to exploration than vanilla MCTS (more on this shortly):

# Run MCTS from the current state
policy_output = mctx.gumbel_muzero_policy(
    params=model,
    rng_key=key,
    root=mctx.RootFnOutput(
        prior_logits=network_logits,
        value=network_value,
        embedding=current_state,  # The state IS the embedding
    ),
    recurrent_fn=recurrent_fn,
    num_simulations=32,          # Search budget per decision
    invalid_actions=~valid_mask,
)

# Extract the improved policy (this is our training target)
mcts_policy = policy_output.action_weights  # Shape: (batch, num_actions)

The action_weights output is the probability distribution over actions that we use as our training target. It is computed from the search tree; in Gumbel MuZero specifically, it's a softmax over the prior logits plus the transformed completed Q-values, which provides a principled policy improvement guarantee.

Under the Hood: How mctx Implements Batched Search

Understanding how mctx works is instructive if you've ever tried to write MCTS and gotten stuck with Python control flow. The core insight is that the entire search loop is implemented using jax.lax.fori_loop and jax.lax.while_loop, which compile to XLA control flow rather than Python loops. Here's the main search loop:

# From mctx/_src/search.py
def body_fun(sim, loop_state):
    rng_key, tree = loop_state
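    rng_key, simulate_key, expand_key = jax.random.split(rng_key, 3)
    # simulate is vmapped over the batch and needs one key per element
    simulate_keys = jax.random.split(simulate_key, batch_size)

    # Simulate: walk down the tree selecting actions
    parent_index, action = simulate(
        simulate_keys, tree, action_selection_fn, max_depth)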

    # Check if we've reached an unexpanded node
    next_node_index = tree.children_index[batch_range, parent_index, action]
    next_node_index = jnp.where(next_node_index == Tree.UNVISITED,
                                sim + 1, next_node_index)

    # Expand: add new node to the tree
    tree = expand(params, expand_key, tree, recurrent_fn, 
                  parent_index, action, next_node_index)
    
    # Backup: propagate values up the tree
    tree = backward(tree, next_node_index)
    return rng_key, tree

The simulation phase uses a while loop to walk down the existing tree until it finds an unvisited edge. At expansion, mctx calls your recurrent function to get the value and prior for the new node. The backup phase then walks from the new leaf back to the root, accumulating the discounted return G and updating the value estimate of every node on the path with an incremental mean, where r and γ are the reward and discount of the edge leading from that node toward the leaf:

\[ G \leftarrow r + \gamma \cdot G, \qquad V_{\text{parent}} \leftarrow \frac{N_{\text{parent}} \cdot V_{\text{parent}} + G}{N_{\text{parent}} + 1}, \qquad \text{with } G = V_{\text{leaf}} \text{ initially} \]
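To make the update concrete, here is a plain-Python sketch of the backup along a single path from a new leaf to the root. The dict-based nodes are purely illustrative (mctx stores the same quantities in flat, batched arrays inside its Tree); at each ancestor the return is extended by the reward and discount of the edge leading toward the leaf before the incremental-mean update.

```python
def backup(path, leaf_value):
    """Propagate a new leaf's value to the root with incremental means.

    path: nodes ordered from the leaf's parent up to the root. Each node is a
          dict with its running mean 'value', visit count 'visits', and the
          'reward'/'discount' of the edge it took toward the leaf.
    """
    g = leaf_value
    for node in path:
        g = node["reward"] + node["discount"] * g  # G <- r + gamma * G
        n = node["visits"]
        node["value"] = (n * node["value"] + g) / (n + 1)  # incremental mean
        node["visits"] = n + 1
    return path


parent = {"value": 0.2, "visits": 1, "reward": 0.1, "discount": 1.0}
root = {"value": 0.4, "visits": 2, "reward": 0.05, "discount": 1.0}
backup([parent, root], leaf_value=0.3)
```

The parent sees G = 0.1 + 0.3 = 0.4 and moves from 0.2 to 0.3; the root sees G = 0.05 + 0.4 = 0.45 and moves from 0.4 to 1.25 / 3.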

Because all of this is expressed as JAX control flow, the entire search compiles to a single XLA program. Combined with vmap over the batch dimension, this makes it feasible to run dozens of simulations per decision at scale.

Gumbel Exploration

We use Gumbel MuZero rather than vanilla MCTS for a specific reason: when you have many actions and only a small simulation budget, traditional AlphaZero-style exploration can fail to visit enough actions to guarantee improvement. Gumbel MuZero, introduced in "Policy Improvement by Planning with Gumbel" (ICLR 2022), fixes this by sampling actions without replacement in a principled way.

The foundation is the Gumbel-Max trick for sampling from a categorical distribution. If your policy network outputs logits \(\ell(a)\), you can sample from \(\pi(a) = \text{softmax}(\ell)\) by adding Gumbel noise and taking an argmax:

\[ g(a) \sim \text{Gumbel}(0,1), \qquad A = \arg\max_a \left[g(a) + \ell(a)\right] \]
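A quick empirical check of the Gumbel-Max trick in JAX: draw many samples of argmax(g + ℓ) and compare the empirical frequencies to softmax(ℓ). The probabilities and sample count are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp


def gumbel_max_samples(key, logits, num_samples):
    """Sample actions from softmax(logits) via argmax of Gumbel-perturbed logits."""
    gumbel = jax.random.gumbel(key, (num_samples,) + logits.shape)
    return jnp.argmax(gumbel + logits, axis=-1)


logits = jnp.log(jnp.array([0.5, 0.3, 0.2]))
samples = gumbel_max_samples(jax.random.PRNGKey(0), logits, num_samples=100_000)
freqs = jnp.bincount(samples, length=3) / samples.shape[0]
```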

The paper extends this to the Gumbel-Top-k trick, which samples k actions without replacement. At the root node, Gumbel MuZero uses this to select a subset of candidate actions, then allocates the simulation budget across those actions using Sequential Halving—a bandit algorithm optimized for simple regret rather than cumulative regret. This is the right objective when you only care about finding the best action, not about the path you took to find it.
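The Gumbel-Top-k step is just as short: perturb the logits once and keep the k highest scores, which yields k distinct actions sampled without replacement. This sketch covers only candidate sampling, not mctx's Sequential Halving schedule, and it excludes invalid actions by setting their scores to -inf before the top-k (our choice for the illustration). Here k plays the role of mctx's max_num_considered_actions.

```python
import jax
import jax.numpy as jnp


def gumbel_top_k(key, logits, k, invalid_actions=None):
    """Sample k distinct actions without replacement (Gumbel-Top-k trick)."""
    gumbel = jax.random.gumbel(key, logits.shape)
    scores = gumbel + logits
    if invalid_actions is not None:
        # Invalid actions can never be among the k candidates
        # (as long as k does not exceed the number of valid actions).
        scores = jnp.where(invalid_actions, -jnp.inf, scores)
    _, actions = jax.lax.top_k(scores, k)
    return actions


logits = jnp.zeros(8)
invalid = jnp.array([False, True, False, True, False, False, True, False])
candidates = gumbel_top_k(jax.random.PRNGKey(1), logits, k=4, invalid_actions=invalid)
```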

You can see the mechanics directly in the mctx code:

# From mctx/_src/policies.py (gumbel_muzero_policy)

# 1) Mask invalid actions at the root
root = root.replace(
    prior_logits=_mask_invalid_actions(root.prior_logits, invalid_actions)
)

# 2) Sample Gumbel noise (same shape as logits)
rng_key, gumbel_rng = jax.random.split(rng_key)
gumbel = gumbel_scale * jax.random.gumbel(
    gumbel_rng, shape=root.prior_logits.shape, dtype=root.prior_logits.dtype
)

# 3) Run batched tree search with Gumbel-aware action selection
search_tree = search.search(
    params=params,
    rng_key=rng_key,
    root=root,
    recurrent_fn=recurrent_fn,
    root_action_selection_fn=functools.partial(
        action_selection.gumbel_muzero_root_action_selection,
        num_simulations=num_simulations,
        max_num_considered_actions=max_num_considered_actions,
        qtransform=qtransform,
    ),
    interior_action_selection_fn=functools.partial(
        action_selection.gumbel_muzero_interior_action_selection,
        qtransform=qtransform,
    ),
    num_simulations=num_simulations,
    invalid_actions=invalid_actions,
    extra_data=action_selection.GumbelMuZeroExtraData(root_gumbel=gumbel),
)

One important detail that might surprise you: in Gumbel MuZero, the returned action_weights are not computed from visit counts as in AlphaZero. Instead, they're softmax(prior_logits + completed_qvalues), where completed_qvalues are the root Q-values with unvisited actions filled in by a value estimate and then passed through the monotone qtransform (σ in the paper). This is a direct implementation of the policy improvement operator from the paper, and that theoretical grounding is part of why Gumbel MuZero works well with small simulation budgets.
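The sketch below computes this improved policy for a single root. It is a simplification of mctx's default qtransform (qtransform_completed_by_mix_value): unvisited actions are completed with a plain value estimate rather than mctx's mixed value, completed Q-values are min-max rescaled to [0, 1], and σ scales them by (c_visit + max N) · c_scale. The constants mirror what we understand to be mctx's defaults (maxvisit_init=50, value_scale=0.1); treat them as assumptions.

```python
import jax
import jax.numpy as jnp


def improved_policy(prior_logits, q_values, visit_counts, value,
                    c_visit=50.0, c_scale=0.1):
    """Simplified Gumbel MuZero target: softmax(prior_logits + sigma(completed Q)).

    prior_logits: [A] network logits at the root
    q_values:     [A] search Q estimates (only meaningful where visited)
    visit_counts: [A] root visit counts N(a)
    value:        scalar estimate used to complete unvisited actions
    """
    # Complete Q: unvisited actions fall back to the value estimate.
    completed_q = jnp.where(visit_counts > 0, q_values, value)
    # Rescale completed Q-values to [0, 1] so sigma does not depend on reward scale.
    q_min, q_max = jnp.min(completed_q), jnp.max(completed_q)
    completed_q = (completed_q - q_min) / jnp.maximum(q_max - q_min, 1e-8)
    # sigma: monotone scaling that grows with the most-visited action's count.
    sigma_q = (c_visit + jnp.max(visit_counts)) * c_scale * completed_q
    return jax.nn.softmax(prior_logits + sigma_q)


pi = improved_policy(
    prior_logits=jnp.zeros(3),
    q_values=jnp.array([1.0, 0.0, 0.0]),
    visit_counts=jnp.array([4, 4, 0]),
    value=0.5,
)
```

Unlike a visit-count target, the unvisited third action keeps nonzero probability (its completed Q sits between the two visited actions), and when all completed Q-values are equal the target reduces to the network's own prior.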

Gumbel MuZero vs Traditional MCTS

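| Aspect                | Traditional MCTS (AlphaZero)      | Gumbel MuZero                                             |
|-----------------------|-----------------------------------|-----------------------------------------------------------|
| Policy target         | n(s,a) / n(s) (visit counts)      | softmax(prior + σ(completed Q)) (policy improvement)      |
| Root exploration      | PUCT + Dirichlet noise            | Gumbel-Top-k + Sequential Halving                         |
| Theoretical guarantee | Asymptotic (infinite simulations) | Finite-sample policy improvement                          |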

The original Expert Iteration paper uses visit count proportions as targets—the idea being that MCTS visits better actions more often. Gumbel MuZero instead computes targets by directly applying a policy improvement operator: it adds the completed Q-values to the prior logits and takes a softmax. This provides stronger guarantees when simulation budgets are small relative to the action space.

The Problem: 3D Bin Packing

What makes 3D bin packing particularly interesting from a reinforcement learning perspective is that it's a perfect-information, deterministic, single-agent planning task. Once you sample a problem instance (a set of items to pack), the environment dynamics are completely known. You place an item, the container state updates deterministically, and you get a reward proportional to the volume you just filled. This structure turns out to be ideal for tree search methods, something we exploit heavily.

Let's formally specify the MDP we're solving. BinPack is an episodic MDP with deterministic dynamics, a large discrete action space, and rewards that are exactly computable from the environment state:

\[ \text{State } s_t = (\text{EMS list},\, \text{items remaining},\, \text{placed items},\, \text{container occupancy}) \] \[ \text{Action } a_t = (\text{ems\_id},\, \text{item\_id}) \in \{0..E-1\}\times\{0..I-1\} \] \[ s_{t+1} = f(s_t, a_t)\quad \text{(exact transition from env.step)} \] \[ r_t = \Delta \text{utilization} \in [0,1] \]

The crucial observation here is that we have access to the exact transition function \(f\). Unlike Atari games where we'd need to learn a world model, or robotics where dynamics are noisy and partially observable, bin packing gives us perfect simulation for free. This means MCTS can roll forward with ground-truth dynamics, and the neural network's job is simply to imitate the search-improved policy. That asymmetry—cheap perfect simulation combined with expensive search—is exactly where search-based policy iteration shines.

The Jumanji Environment

We use InstaDeep's Jumanji library, which provides a JAX-native BinPack environment. The key abstraction that makes this environment tractable is the concept of Empty Maximal Spaces (EMS)—rectangular regions inside the container where items can potentially be placed.

At each step, the agent makes a joint decision: which EMS to place an item in (from up to 40 candidates), and which item to place (from up to 20 items). This gives a joint action space of 40 × 20 = 800 discrete actions. The environment uses dense rewards, meaning each placement adds the item's volume (normalized by container volume) to the cumulative return. A perfect packing that uses all available space yields a return of 1.0.

# Environment setup
import jax.numpy as jnp
import jumanji

env = jumanji.make("BinPack-v2")

# Action space dimensions
obs_num_ems = 40   # Observable empty maximal spaces
max_num_items = 20  # Maximum items per episode
num_actions = obs_num_ems * max_num_items  # 800 total

# We flatten the action for simpler MCTS handling
def unflatten_action(action):
    """Flat index → (ems_id, item_id)"""
    ems_id = action // max_num_items
    item_id = action % max_num_items
    return jnp.stack([ems_id, item_id], axis=-1)
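Because the policy head and the action mask are both flattened, it is worth checking that the flat index agrees with a row-major reshape of the (E, I) mask. This sketch adds a hypothetical flatten_action inverse (not part of Jumanji) and verifies the round trip for all 800 actions.

```python
import jax.numpy as jnp

OBS_NUM_EMS = 40
MAX_NUM_ITEMS = 20


def flatten_action(ems_id, item_id):
    """Hypothetical inverse: (ems_id, item_id) -> flat index, row-major over (E, I)."""
    return ems_id * MAX_NUM_ITEMS + item_id


def unflatten_action(action):
    """Flat index -> (ems_id, item_id), same convention as above."""
    return jnp.stack([action // MAX_NUM_ITEMS, action % MAX_NUM_ITEMS], axis=-1)


actions = jnp.arange(OBS_NUM_EMS * MAX_NUM_ITEMS)
pairs = unflatten_action(actions)
# An arbitrary (E, I) mask; flattening it row-major must line up with flat indices.
mask = (actions.reshape(OBS_NUM_EMS, MAX_NUM_ITEMS) % 3) == 0
```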

The environment's observation structure is designed to be "planning-friendly"—rather than giving raw pixels or low-level state, it provides two sets of tokens (EMS and items) plus boolean masks defining what's valid. This aligns nicely with transformer-style encoders and with our flattened (E×I) action head. Each EMS is represented by its 6D bounding box coordinates (x1, x2, y1, y2, z1, z2), and each item by its three dimensions (x_len, y_len, z_len). The critical piece is the action_mask: a boolean matrix where entry (e, i) is True if item i can legally be placed in EMS e.

How does the environment compute this joint action mask? For each EMS and each item, it tests whether the item fits in that EMS and whether the item is still available. This is implemented with nested jax.vmap, so the entire mask computation happens in parallel:

# From jumanji/environments/packing/bin_pack/env.py
def is_action_allowed(ems, ems_mask, item, item_mask, item_placed):
    item_fits_in_ems = item_fits_in_item(item, item_from_space(ems))
    return ~item_placed & item_mask & ems_mask & item_fits_in_ems

action_mask = jax.vmap(
    jax.vmap(is_action_allowed, in_axes=(None, None, 0, 0, 0)),
    in_axes=(0, 0, None, None, None),
)(obs_ems, obs_ems_mask, items, items_mask, items_placed)
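A stripped-down, self-contained version of the same nested-vmap pattern is below. It is a toy: "fits" is reduced to a per-axis comparison between an item's dimensions and an EMS's extents, standing in for Jumanji's item_fits_in_item and item_from_space helpers, which are not reproduced here.

```python
import jax
import jax.numpy as jnp


def item_fits(ems, item):
    """ems: [6] = (x1, x2, y1, y2, z1, z2); item: [3] = (x_len, y_len, z_len)."""
    extents = jnp.stack([ems[1] - ems[0], ems[3] - ems[2], ems[5] - ems[4]])
    return jnp.all(item <= extents)


def is_action_allowed(ems, ems_mask, item, item_mask, item_placed):
    return ~item_placed & item_mask & ems_mask & item_fits(ems, item)


def toy_action_mask(ems, ems_mask, items, items_mask, items_placed):
    # Inner vmap over items, outer vmap over EMS -> mask of shape [E, I].
    return jax.vmap(
        jax.vmap(is_action_allowed, in_axes=(None, None, 0, 0, 0)),
        in_axes=(0, 0, None, None, None),
    )(ems, ems_mask, items, items_mask, items_placed)


ems = jnp.array([[0, 1, 0, 1, 0, 1], [0, 2, 0, 2, 0, 2]], dtype=jnp.float32)
items = jnp.array([[1, 1, 1], [2, 2, 2], [1, 1, 3]], dtype=jnp.float32)
all_items = jnp.array([True, True, True])
mask = toy_action_mask(ems, jnp.array([True, True]), items, all_items,
                       jnp.array([False, False, False]))
```

The unit cube only accepts the unit item, the 2-cube accepts the first two, and the 1×1×3 item fits nowhere, giving [[T, F, F], [T, T, F]].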

One subtle detail: the full environment can track more than 40 EMS internally, but the observation only returns the obs_num_ems largest by volume. This keeps the observation size fixed, turning a variable-structure geometric process into a fixed-shape tensor problem—exactly what we need for efficient JAX/XLA compilation. The environment is also strict about invalid actions: if you choose an action where the item doesn't fit or has already been placed, the episode terminates immediately. This makes action masking a first-class concern, both in PPO and especially in MCTS where we want zero probability mass on invalid moves.

The EMS representation is a classic way to represent free space as a non-disjoint set of maximal empty cuboids. If you're interested in the "pre-RL" perspective on this representation, Parreño et al.'s technical report on GRASP heuristics and Zhao et al.'s comparative review of 3D container loading provide excellent background. The key idea is that EMS makes the problem look like "choose a space + choose an item", which is exactly how our policy and value networks will model it.

JAX Parallelism

The historical challenge with search-based policy iteration is computational cost. Each decision made with 32 MCTS simulations costs roughly 32 extra network forward passes and environment steps, compared with the single forward pass a model-free policy needs. But JAX's parallelization primitives change the calculus dramatically.

The key primitives are jax.vmap for vectorizing over the batch dimension and jax.pmap for distributing across devices. With vmap, a single line of code turns a function that processes one environment into a function that processes thousands in parallel:

# Without vmap: process episodes one at a time
for i in range(batch_size):
    state, timestep = env.step(states[i], actions[i])

# With vmap: all episodes processed in parallel
state, timestep = jax.vmap(env.step)(states, actions)

The pmap primitive replicates computation across all available GPUs or TPUs. Each device processes a shard of the batch independently, and gradients are synchronized using jax.lax.pmean:

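from functools import partial

import jax
import optax

# loss_fn and optimizer (an optax GradientTransformation) are defined elsewhere;
# model, opt_state, and batch all carry a leading device axis.
@partial(jax.pmap, axis_name="devices")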
def train_step(model, opt_state, batch):
    # Compute gradients on this device's shard
    grads = jax.grad(loss_fn)(model, batch)
    
    # Synchronize gradients across all devices
    grads = jax.lax.pmean(grads, axis_name="devices")
    
    # Apply optimizer update
    updates, opt_state = optimizer.update(grads, opt_state)
    model = optax.apply_updates(model, updates)
    return model, opt_state
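The step that snippet leaves implicit is getting data into the per-device layout. Below is a minimal runnable sketch with a toy least-squares loss standing in for the real loss_fn: the batch is reshaped to (num_devices, per_device_batch, ...), parameters are replicated with one copy per device, and pmean gives every device the full-batch gradient. The helper names are illustrative; on a single CPU the device axis simply has size 1.

```python
from functools import partial

import jax
import jax.numpy as jnp


def shard(batch):
    """Reshape every leaf from [B, ...] to [num_devices, B // num_devices, ...]."""
    n = jax.local_device_count()
    return jax.tree_util.tree_map(lambda x: x.reshape(n, -1, *x.shape[1:]), batch)


@partial(jax.pmap, axis_name="devices")
def synced_grad(w, x, y):
    # Toy least-squares loss on this device's shard.
    grad = jax.grad(lambda w: jnp.mean((x @ w - y) ** 2))(w)
    # Equal-sized shards: the mean of per-shard gradients is the full-batch gradient.
    return jax.lax.pmean(grad, axis_name="devices")


n_dev = jax.local_device_count()
x = jax.random.normal(jax.random.PRNGKey(0), (8 * n_dev, 3))
y = jnp.arange(8 * n_dev, dtype=jnp.float32)
w = jnp.ones(3)
x_s, y_s = shard((x, y))
w_rep = jnp.stack([w] * n_dev)  # one parameter copy per device
grads = synced_grad(w_rep, x_s, y_s)
```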

Together, vmap and pmap give us high throughput. On 4 GPUs with batch size 1024, we process roughly 20,000 MCTS-guided decisions per second, enough to keep wall-clock training time competitive with PPO despite running 32 simulations per decision. The JAX documentation on parallelism provides an excellent introduction to these concepts.

Neural Network Architecture

We adapt Jumanji's A2C architecture for our policy-value network. The key idea is using cross-attention between EMS tokens and item tokens. This lets the network reason about which items fit in which spaces, with the action mask gating the attention to only consider valid placement combinations.

The architecture processes EMS and item tokens through alternating layers of self-attention (within each token type) and cross-attention (between types). After encoding, the policy head computes logits via a bilinear form—logits[e,i] = ems_h[e] · items_h[i]—which naturally produces the (E × I) shaped output we need. The value head pools the embeddings and predicts expected utilization in [0, 1].

# Architecture overview
#
#   ┌─────────────┐     ┌─────────────┐
#   │  EMS tokens │     │ Item tokens │
#   │ (40 × 6D)   │     │ (20 × 3D)   │
#   └──────┬──────┘     └──────┬──────┘
#          │                   │
#          ▼                   ▼
#   ┌─────────────┐     ┌─────────────┐
#   │ Self-Attn   │     │ Self-Attn   │
#   └──────┬──────┘     └──────┬──────┘
#          │                   │
#          └────────┬──────────┘
#                   ▼
#          ┌───────────────┐
#          │ Cross-Attn    │  <- EMS ↔ Items interaction
#          │(bidirectional)│    (gated by action_mask)
#          └───────┬───────┘
#                  │
#          ┌───────┴───────┐
#          ▼               ▼
#   ┌─────────────┐  ┌─────────────┐
#   │ Policy Head │  │ Value Head  │
#   │ (E×I logits)│  │ (scalar)    │
#   └─────────────┘  └─────────────┘
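The bilinear policy head can be written as a single einsum. This sketch assumes EMS and item embeddings share a width D and applies the same finite-minimum masking as apply_action_mask; the row-major flatten matches the unflatten_action convention. It is an illustration, not the project's actual module.

```python
import jax
import jax.numpy as jnp


def bilinear_policy_logits(ems_h, items_h, action_mask):
    """logits[e, i] = ems_h[e] . items_h[i], masked and flattened to E * I.

    ems_h: [..., E, D], items_h: [..., I, D], action_mask: [..., E, I] (bool)
    """
    logits = jnp.einsum("...ed,...id->...ei", ems_h, items_h)
    # Finite minimum instead of -inf keeps softmax and cross-entropy NaN-free.
    logits = jnp.where(action_mask, logits, jnp.finfo(logits.dtype).min)
    # Row-major flatten: flat index = e * I + i.
    return logits.reshape(*logits.shape[:-2], -1)


key_e, key_i = jax.random.split(jax.random.PRNGKey(0))
ems_h = jax.random.normal(key_e, (2, 5))    # E = 2 spaces, D = 5
items_h = jax.random.normal(key_i, (3, 5))  # I = 3 items
mask = jnp.array([[True, True, False], [True, True, True]])
flat_logits = bilinear_policy_logits(ems_h, items_h, mask)
```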

We use Haiku to build the network. Haiku's hk.transform_with_state pattern cleanly separates the stateful module definition from the pure functional interface that JAX requires:

import haiku as hk
import jax

def forward_fn(observation, is_eval=False):
    # BinPackPolicyValueNet is the project's own hk.Module (encoder + heads).
    net = BinPackPolicyValueNet(
        num_transformer_layers=4,
        transformer_num_heads=4,
        transformer_key_size=32,
        transformer_mlp_units=(256, 256),
    )
    return net(observation, is_training=not is_eval)

# Transform to pure functions
forward = hk.without_apply_rng(hk.transform_with_state(forward_fn))

# Initialize parameters from any observation with the right shapes
# (dummy_obs, e.g. taken from env.reset)
params, state = forward.init(jax.random.PRNGKey(0), dummy_obs)

# Apply is now a pure function
(logits, value), new_state = forward.apply(params, state, observation)

Experiments

We trained both search distillation and PPO on BinPack-v2 for 800 iterations across 5 random seeds. To ensure a fair comparison, both methods use the same network architecture, batch sizes, and evaluation protocol. The key hyperparameters are:

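| Hyperparameter           | Search Distillation | PPO  |
|--------------------------|---------------------|------|
| Rollout batch size       | 1024                | 1024 |
| Training batch size      | 4096                | 4096 |
| MCTS simulations         | 32                  | n/a  |
| PPO epochs per iteration | n/a                 | 4    |
| Learning rate            | 1e-3                | 3e-4 |
| Transformer layers       | 4                   | 4    |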
PPO vs Search Distillation learning curves
Figure 1. Learning curves comparing search distillation (light blue) vs PPO (dark blue) on BinPack-v2. Shaded regions show ±1 standard error across 5 seeds. Search distillation converges to 96% volume utilization while PPO plateaus around 90%.

Search distillation reaches 96% volume utilization versus 90% for PPO, a gap of 6 percentage points (about 6.7% relative). Despite running 32 MCTS simulations per decision, JAX parallelism keeps training time competitive: both methods converge in roughly the same wall-clock time.

Implementation Details

A few implementation details are worth highlighting, as they address common pitfalls when working with JAX and MCTS.

Using jax.lax.scan for Efficient Loops

Python loops inside jit-compiled functions are unrolled during tracing, so the compiled program, and the time to compile it, grows with the number of steps; for long episodes this becomes slow and memory-hungry. The solution is jax.lax.scan, which traces the loop body once and executes it repeatedly:

import jax

# `step(state, key) -> (state, data)` stands in for one env + search step.

# Inefficient: Python loop is unrolled at trace time
def rollout_bad(state, keys):
    trajectory = []
    for key in keys:
        state, data = step(state, key)
        trajectory.append(data)
    return trajectory

# Efficient: scan compiles once, runs fast
def rollout_good(state, keys):
    def step_fn(carry, key):
        state = carry
        state, data = step(state, key)
        return state, data
    
    final_state, trajectory = jax.lax.scan(step_fn, state, keys)
    return trajectory
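Here is a self-contained version of the scan pattern that actually runs, with a hypothetical toy_step in place of the real environment-plus-search step. It also shows the property that makes scan convenient for rollouts: per-step outputs come back stacked along a new leading time axis, ready for target computation.

```python
import jax
import jax.numpy as jnp


def toy_step(total, key):
    """Hypothetical step: draw a reward and add it to a running total."""
    reward = jax.random.uniform(key)
    return total + reward, reward  # (new carry, per-step output)


@jax.jit
def rollout(initial_total, keys):
    # The body is traced once, however many steps there are.
    final_total, rewards = jax.lax.scan(toy_step, initial_total, keys)
    return final_total, rewards


keys = jax.random.split(jax.random.PRNGKey(0), 50)
final_total, rewards = rollout(jnp.zeros(()), keys)
```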

Action Masking Without NaNs

Invalid actions must be masked before computing the softmax for action selection. A common mistake is using -inf for invalid logits, which causes NaN when all actions are invalid (as can happen at episode termination). The solution is to use a large but finite negative value:

def apply_action_mask(logits, valid_mask):
    """Set invalid action logits to a large negative value."""
    # Center for numerical stability
    logits = logits - jnp.max(logits, axis=-1, keepdims=True)
    # Use finite minimum (not -inf) to avoid NaN
    return jnp.where(valid_mask, logits, jnp.finfo(logits.dtype).min)

The mctx library uses the same approach for the same reason. At the end of an episode, all actions can become invalid, and the code needs to handle this gracefully.
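A quick check of that failure mode, using the same apply_action_mask: when every action is invalid, the masked logits still give a finite (uniform) softmax, and when exactly one action is valid it receives all of the probability mass.

```python
import jax
import jax.numpy as jnp


def apply_action_mask(logits, valid_mask):
    """Set invalid action logits to the most negative finite value."""
    logits = logits - jnp.max(logits, axis=-1, keepdims=True)
    return jnp.where(valid_mask, logits, jnp.finfo(logits.dtype).min)


logits = jnp.array([2.0, -1.0, 0.5, 3.0])

# End of episode: nothing is valid. With -inf this softmax would be all NaN.
all_invalid = jax.nn.softmax(apply_action_mask(logits, jnp.zeros(4, dtype=bool)))

# Exactly one valid action gets all of the probability mass.
one_valid = jax.nn.softmax(
    apply_action_mask(logits, jnp.array([False, False, True, False]))
)
```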

When to Use Search Distillation

Search-based policy iteration is particularly well-suited to problems where you have exact dynamics (so MCTS can simulate accurately), episodes are relatively short (so MCTS overhead doesn't compound excessively), actions have complex dependencies that gradient signals might miss, and you have parallel compute to amortize the cost of search. Bin packing hits all four criteria.

PPO remains preferable when dynamics must be learned rather than simulated, horizons are very long (hundreds of steps), action spaces are simple with clear reward gradients, or when you need the simplicity of a model-free approach.

References

Code

The complete implementation is available on GitHub. Key dependencies include mctx for MCTS, Jumanji for the BinPack environment, Haiku for neural networks, and Optax for optimization. The PGX AlphaZero example was a helpful reference for JAX-native AlphaZero patterns.

Papers

  1. Anthony et al., "Thinking Fast and Slow with Deep Learning and Tree Search", NeurIPS 2017. The original Expert Iteration paper—closely related to our approach.
  2. Danihelka et al., "Policy Improvement by Planning with Gumbel", ICLR 2022. The theoretical foundation for Gumbel MuZero and our policy improvement operator.
  3. Grill et al., "Monte-Carlo Tree Search as Regularized Policy Optimization", ICML 2020. Interprets MCTS as approximately solving a regularized policy optimization problem.
  4. Silver et al., "Mastering the game of Go without human knowledge", Nature 2017. The AlphaGo Zero paper, the direct predecessor of AlphaZero.
  5. Schrittwieser et al., "Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model", Nature 2020. The MuZero paper.
  6. Schulman et al., "Proximal Policy Optimization Algorithms", arXiv 2017. The PPO paper.
  7. Bonnet et al., "Jumanji: a Diverse Suite of Scalable Reinforcement Learning Environments in JAX", 2023. The Jumanji environment suite including BinPack.
  8. Karnin, Koren, Somekh, "Almost Optimal Exploration in Multi-Armed Bandits", ICML 2013. The Sequential Halving algorithm used in Gumbel MuZero.

Citation

@misc{muppidi2026searchdistill,
  title={Policy Iteration via Search Distillation for 3D Bin Packing},
  author={Muppidi, Aneesh},
  year={2026},
  url={https://github.com/Aneeshers/policy_search_distillation}
}