MCTS as a policy improvement operator. For single-agent planning with known dynamics, we use Gumbel MuZero search to compute improved policy targets—softmax(prior + Q)—then distill them into a neural network via cross-entropy. This is policy iteration, but the "improvement" step is search and the "evaluation" step is supervised learning. With JAX's vmap and pmap, the search runs fast enough to compete with PPO on wall-clock time—while achieving 96% volume utilization vs PPO's 90% on 3D bin packing. The approach is closely related to Expert Iteration, but with important differences we make explicit.
When AlphaZero achieved superhuman performance at Go and Chess, the key innovation wasn't just MCTS or neural networks—it was how the two components reinforced each other. The network provides fast intuition to guide search, and search provides high-quality targets to improve the network. But there's an important detail that's often glossed over: AlphaZero uses self-play because Go and Chess are two-player games. The network plays against itself, generating training data from both sides of the board.
But many important problems don't have an opponent. Consider 3D bin packing: you're given a set of rectangular boxes and a container, and your goal is to pack as many boxes as possible to maximize volume utilization. It sounds like something you could solve by just being clever about placement order, but the combinatorial explosion makes this problem far harder than it appears.
In its 3D form, bin packing is equivalent to the 3D Knapsack Problem (Knapsack itself is one of Karp's 21 NP-complete problems). (For a thorough treatment of computational complexity and bin packing specifically, Garey and Johnson's Computers and Intractability remains the canonical reference.) The difficulty is navigating the astronomical number of possible placements: with n items and m potential positions, the search space grows roughly as \(O(m^n)\). Even for modest problem sizes, brute-force search is hopeless.
For problems like this, self-play doesn't apply—there's no "other side" to play. Instead, we run search on each state and use the resulting action distribution as a training target. This is closely related to Expert Iteration (Anthony et al., 2017), which frames the idea as "MCTS is an expert teacher, the network is an apprentice."
But as we'll see, our implementation differs from vanilla Expert Iteration in important ways—most notably, we use Gumbel MuZero's policy improvement operator rather than visit-count targets. So we prefer a more precise framing: policy iteration via search distillation.
The algorithm is best understood as policy iteration—the classic RL framework where you alternate between evaluating your current policy and improving it. But our implementation of each step is unusual:
- Improvement: Gumbel MuZero search, which computes softmax(prior + Q) as the improved action distribution.
- Evaluation: supervised distillation, training the network via cross-entropy to match that distribution.

The network gets better, which makes search stronger (since it uses the network for priors and value estimates), which produces even better targets, and so on. This is the core loop:
┌─────────────────────────────────────────────────────────────────────────────┐
│ │
│ ┌─────────┐ improved ┌─────────┐ distill ┌───────┐ │
│ │ Gumbel │ ──────────────────> │ Policy │ ─────────────────>│ NN │ │
│ │ MuZero │ targets │ Targets │ into │Policy │ │
│ └─────────┘ └─────────┘ └───────┘ │
│ ▲ │ │
│ │ next iteration │ │
│ └────────────────────────────────────────────────────────────┘ │
│ │
│ The network improves -> search becomes stronger -> better targets -> ... │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
This approach is closely related to Expert Iteration (Anthony et al., 2017), which frames the same idea as "MCTS is an expert teacher, the network is an apprentice." The ExIt paper's Appendix A even notes that AlphaGo Zero "independently developed the EXIT algorithm." So the lineage is clear.
However, there are meaningful differences that make "Policy Iteration via Search Distillation" a more precise description of what we're actually doing:
1. Policy targets are not visit counts. Traditional ExIt uses n(s,a) / n(s)—the proportion of times MCTS visited each action. We use Gumbel MuZero, which computes softmax(prior_logits + completed_Q).
This is a direct policy improvement operator with finite-sample guarantees, not an asymptotic property of visit counts.
2. The "expert" is not independent. In classical ExIt framing, the expert is often conceptualized as a separate, stronger system. Here, the expert is the current network + search—they're coupled. MCTS uses the network's priors to guide search and the network's value estimates to evaluate leaves. This is closer to "compute-amplified self-improvement" than "learning from an external expert."
3. Value targets come from Monte Carlo returns, not MCTS. The value head is trained via supervised regression to empirical returns from rollouts—not from MCTS root values or backed-up estimates. So only the policy is "search-improved"; the value is trained by standard Monte Carlo.
4. Single-agent, no self-play. AlphaZero uses self-play because Go and Chess are two-player games. Our target domain (bin packing) is single-agent—there's no opponent. We simply run search on the current state and distill the result.
None of these differences make the approach invalid. They just mean the ExIt framing is slightly imprecise. "Policy Iteration via Search Distillation" captures the mechanics more accurately: search is the improvement operator, supervised learning is how we internalize the improvement.
It's worth being explicit about what kind of algorithm this is and isn't. PPO and other policy gradient methods train the policy directly from reward:

\[ \nabla_\theta J(\theta) = \mathbb{E}_t\left[ \nabla_\theta \log \pi_\theta(a_t|s_t)\, \hat{A}_t \right] \]

The gradient signal comes from sampled actions and sampled advantages. This is high-variance, which is why PPO needs clipping, GAE, entropy bonuses, and careful hyperparameter tuning.
Our approach is fundamentally different. The policy is trained by imitation:

\[ \mathcal{L}_{\text{policy}}(\theta) = -\sum_a \pi_{\text{MCTS}}(a|s) \log \pi_\theta(a|s) \]

This is cross-entropy between the network's output and the search-produced distribution. Reward doesn't appear in this loss at all—it only influences the policy indirectly through how search evaluates actions. The supervision signal is dense (a full distribution over actions) rather than sparse (a single sampled action and its return).
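As a sketch of this loss (a hypothetical `distillation_loss` helper in plain numpy; the real training code operates on batches under jit):

```python
import numpy as np

def distillation_loss(network_logits, search_policy, valid_mask):
    """Cross-entropy between the search distribution and the network policy.

    network_logits: (num_actions,) raw logits from the policy head
    search_policy:  (num_actions,) action_weights produced by search
    valid_mask:     (num_actions,) boolean mask of legal actions
    """
    # Mask invalid actions with a large finite negative value (not -inf)
    logits = np.where(valid_mask, network_logits,
                      np.finfo(network_logits.dtype).min)
    # Stable log-softmax
    logits = logits - logits.max()
    log_probs = logits - np.log(np.exp(logits).sum())
    # Note: reward appears nowhere in this loss
    return -(search_policy * log_probs).sum()

# The loss bottoms out at the entropy of the search distribution,
# reached exactly when the network reproduces it
target = np.array([0.7, 0.2, 0.1])
mask = np.ones(3, dtype=bool)
loss = distillation_loss(np.log(target), target, mask)
assert abs(loss - (-(target * np.log(target)).sum())) < 1e-9
```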
So where does reward come in? Two places:

1. Inside search: MCTS backups use reward + discount * value to evaluate nodes, so reward shapes which actions search prefers.
2. In the value targets: the value head regresses on empirical Monte Carlo returns from rollouts.

But the policy loss itself is purely imitation of search. In theory, this should make the learning curves more stable than PPO's, since we're doing supervised learning with high-quality targets, not high-variance policy gradients.
To be completely explicit: the key idea is that search is the policy improvement operator. Given any policy π, running Gumbel MuZero search produces a better policy π' (under mild conditions). Training the network to match π' internalizes that improvement. Repeat.
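The improvement property can be checked numerically. A minimal numpy sketch (using an identity qtransform; mctx additionally rescales Q-values before adding them to the logits):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def improved_policy(prior_logits, q_values):
    """Gumbel MuZero's improvement operator: softmax(prior + Q).

    Identity qtransform for simplicity; mctx rescales the Q-values
    before adding them to the logits.
    """
    return softmax(prior_logits + q_values)

# Tilting the prior by Q never decreases the expected Q-value:
# the improved policy pi' satisfies E_{pi'}[Q] >= E_{pi}[Q]
rng = np.random.default_rng(0)
for _ in range(100):
    logits = rng.normal(size=8)
    q = rng.normal(size=8)
    assert improved_policy(logits, q) @ q >= softmax(logits) @ q - 1e-12
```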
PPO is an on-policy policy gradient method that maximizes a clipped surrogate objective:

\[ L^{\text{CLIP}}(\theta) = \mathbb{E}_t\left[ \min\left( r_t(\theta)\hat{A}_t,\ \text{clip}(r_t(\theta),\ 1-\epsilon,\ 1+\epsilon)\hat{A}_t \right) \right] \]

where \(r_t(\theta) = \pi_\theta(a_t|s_t) / \pi_{\theta_{\text{old}}}(a_t|s_t)\) is the probability ratio. PPO works well when you can collect lots of trajectories and the environment is too complex for planning. But PPO's gradient signal is based on sampled actions and sampled returns, so variance control (GAE, baselines, clipping) becomes essential.
Search distillation sidesteps this variance problem entirely. Instead of learning from a single sampled action and its noisy return, we learn from the full search distribution, a much richer signal. And because we have exact environment dynamics, search can do genuine lookahead planning, discovering good action sequences that pure policy gradients might take much longer to find.
| Aspect | PPO | Search Distillation |
|---|---|---|
| Policy learning signal | Sampled action + advantage | Full distribution from search |
| Reward's role in policy | Direct (policy gradient) | Indirect (shapes search preferences) |
| Variance | High (needs GAE, clipping) | Low (supervised learning) |
| Lookahead | None (reactive) | Multiple simulations per decision |
| Requires dynamics | No | Yes (for search) |
We use mctx, DeepMind's JAX-native MCTS library. The library is worth studying because it implements batched MCTS that compiles entirely under XLA—the entire search runs as a single compiled program, not Python loops calling into JAX. This is what makes it fast enough to compete with model-free methods on wall-clock time.
The mctx library requires you to provide a "recurrent function" that simulates environment transitions. For problems with known dynamics, this is trivial—we just call the environment's step function:
import mctx
def recurrent_fn(model, rng_key, action, state):
"""Environment model for MCTS simulation."""
model_params, model_state = model
# Step the environment (we have perfect dynamics!)
action_pair = unflatten_action(action)
next_state, timestep = jax.vmap(env.step)(state, action_pair)
# Get network predictions for the next state
observation = timestep.observation
(logits, value), _ = forward.apply(
model_params, model_state, observation, is_eval=True
)
# Mask invalid actions
valid_mask = get_valid_action_mask(observation.action_mask)
logits = apply_action_mask(logits, valid_mask)
return mctx.RecurrentFnOutput(
reward=timestep.reward,
discount=timestep.discount,
prior_logits=logits,
value=value,
), next_state
With this function defined, running MCTS is straightforward. We use gumbel_muzero_policy, which provides a more principled approach to exploration than vanilla MCTS (more on this shortly):
# Run MCTS from the current state
policy_output = mctx.gumbel_muzero_policy(
params=model,
rng_key=key,
root=mctx.RootFnOutput(
prior_logits=network_logits,
value=network_value,
embedding=current_state, # The state IS the embedding
),
recurrent_fn=recurrent_fn,
num_simulations=32, # Search budget per decision
invalid_actions=~valid_mask,
)
# Extract the improved policy (this is our training target)
mcts_policy = policy_output.action_weights # Shape: (batch, num_actions)
The action_weights output is the probability distribution over actions that we use as our training target. This is computed from the search tree—in Gumbel MuZero specifically, it's a softmax over the
prior logits plus the completed Q-values, which provides a principled policy improvement guarantee.
Understanding how mctx works is instructive if you've ever tried to write MCTS and gotten stuck with Python control flow. The core insight is that the entire search loop is implemented using jax.lax.fori_loop and jax.lax.while_loop, which compile to XLA control flow rather than Python loops. Here's the main search loop:
# From mctx/_src/search.py
def body_fun(sim, loop_state):
rng_key, tree = loop_state
rng_key, simulate_key, expand_key = jax.random.split(rng_key, 3)
# Simulate: walk down the tree selecting actions
    parent_index, action = simulate(
        simulate_key, tree, action_selection_fn, max_depth)
# Check if we've reached an unexpanded node
next_node_index = tree.children_index[batch_range, parent_index, action]
next_node_index = jnp.where(next_node_index == Tree.UNVISITED,
sim + 1, next_node_index)
# Expand: add new node to the tree
tree = expand(params, expand_key, tree, recurrent_fn,
parent_index, action, next_node_index)
# Backup: propagate values up the tree
tree = backward(tree, next_node_index)
return rng_key, tree
The simulation phase uses a while loop to walk down the existing tree until it finds an unvisited edge. At expansion, mctx calls your recurrent function to get the value and prior for the new node. The backup phase then propagates the leaf value up to the root using an incremental mean update:

\[ V(s) \leftarrow \frac{n(s)\,V(s) + G}{n(s) + 1}, \qquad n(s) \leftarrow n(s) + 1 \]

where \(G\) is the return backed up from the leaf.
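The incremental mean itself is simple enough to verify in isolation (a simplified sketch; the real backward pass also discounts the return as it moves toward the root):

```python
import numpy as np

def backup_value(node_value, visit_count, leaf_return):
    """One backward-pass update at a node: incremental mean of returns."""
    new_value = (visit_count * node_value + leaf_return) / (visit_count + 1)
    return new_value, visit_count + 1

# Folding in returns one at a time reproduces the batch mean
returns = [0.5, 0.9, 0.1, 0.7]
value, count = 0.0, 0
for g in returns:
    value, count = backup_value(value, count, g)
assert abs(value - np.mean(returns)) < 1e-12
assert count == 4
```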
Because all of this is expressed as JAX control flow, the entire search compiles to a single XLA program. Combined with vmap over the batch dimension, this makes it feasible to run dozens of simulations
per decision at scale.
We use Gumbel MuZero rather than vanilla MCTS for a specific reason: when you have many actions and only a small simulation budget, traditional AlphaZero-style exploration can fail to visit enough actions to guarantee improvement. Gumbel MuZero, introduced in "Policy Improvement by Planning with Gumbel" (ICLR 2022), fixes this by sampling actions without replacement in a principled way.
The foundation is the Gumbel-Max trick for sampling from a categorical distribution. If your policy network outputs logits \(\ell(a)\), you can sample from \(\pi(a) = \text{softmax}(\ell)\) by adding Gumbel noise and taking an argmax:

\[ A = \arg\max_a \left( \ell(a) + g(a) \right), \qquad g(a) \sim \text{Gumbel}(0, 1) \]
The paper extends this to the Gumbel-Top-k trick, which samples k actions without replacement. At the root node, Gumbel MuZero uses this to select a subset of candidate actions, then allocates the simulation budget across those actions using Sequential Halving—a bandit algorithm optimized for simple regret rather than cumulative regret. This is the right objective when you only care about finding the best action, not about the path you took to find it.
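Both tricks are easy to verify empirically. A plain numpy sketch (gumbel_max_sample and gumbel_top_k are illustrative helpers, not mctx API):

```python
import numpy as np

def gumbel_max_sample(logits, rng):
    """Sample from softmax(logits) via argmax over Gumbel-perturbed logits."""
    return int(np.argmax(logits + rng.gumbel(size=logits.shape)))

def gumbel_top_k(logits, k, rng):
    """Gumbel-Top-k: k distinct actions, sampled without replacement."""
    return np.argsort(logits + rng.gumbel(size=logits.shape))[::-1][:k]

rng = np.random.default_rng(0)
logits = np.log(np.array([0.5, 0.3, 0.2]))

# Empirical frequencies approximate softmax(logits)
counts = np.bincount(
    [gumbel_max_sample(logits, rng) for _ in range(20000)], minlength=3)
assert np.allclose(counts / counts.sum(), [0.5, 0.3, 0.2], atol=0.02)

# Top-k never repeats an action
assert len(set(gumbel_top_k(logits, 3, rng).tolist())) == 3
```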
You can see the mechanics directly in the mctx code:
# From mctx/_src/policies.py (gumbel_muzero_policy)
# 1) Mask invalid actions at the root
root = root.replace(
prior_logits=_mask_invalid_actions(root.prior_logits, invalid_actions)
)
# 2) Sample Gumbel noise (same shape as logits)
rng_key, gumbel_rng = jax.random.split(rng_key)
gumbel = gumbel_scale * jax.random.gumbel(
gumbel_rng, shape=root.prior_logits.shape, dtype=root.prior_logits.dtype
)
# 3) Run batched tree search with Gumbel-aware action selection
search_tree = search.search(
params=params,
rng_key=rng_key,
root=root,
recurrent_fn=recurrent_fn,
root_action_selection_fn=functools.partial(
action_selection.gumbel_muzero_root_action_selection,
num_simulations=num_simulations,
max_num_considered_actions=max_num_considered_actions,
qtransform=qtransform,
),
interior_action_selection_fn=functools.partial(
action_selection.gumbel_muzero_interior_action_selection,
qtransform=qtransform,
),
num_simulations=num_simulations,
invalid_actions=invalid_actions,
extra_data=action_selection.GumbelMuZeroExtraData(root_gumbel=gumbel),
)
One important detail that might surprise you: in Gumbel MuZero, the returned action_weights are not computed from visit counts like in AlphaZero. Instead, they're softmax(prior_logits + completed_qvalues)—a
direct implementation of the policy improvement operator from the paper. This theoretical grounding is part of why Gumbel MuZero works well with small simulation budgets.
| Aspect | Traditional MCTS (AlphaZero) | Gumbel MuZero |
|---|---|---|
| Policy target | n(s,a) / n(s) (visit counts) | softmax(prior + Q) (policy improvement) |
| Root exploration | UCT + Dirichlet noise | Gumbel-Top-k + Sequential Halving |
| Theoretical guarantee | Asymptotic (infinite simulations) | Finite-sample policy improvement |
The original Expert Iteration paper uses visit count proportions as targets—the idea being that MCTS visits better actions more often. Gumbel MuZero instead computes targets by directly applying a policy improvement operator: it adds the completed Q-values to the prior logits and takes a softmax. This provides stronger guarantees when simulation budgets are small relative to the action space.
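A toy example makes the difference concrete. Suppose search had budget for only 4 simulations over 5 actions (numbers invented for illustration; real Gumbel MuZero also "completes" the Q-values of unvisited actions with a value estimate):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

prior_logits = np.zeros(5)                      # uniform prior
q_values = np.array([0.9, 0.5, 0.4, 0.3, 0.2])  # Q estimates from search
visits = np.array([2, 1, 1, 0, 0])              # only 4 simulations spent

# AlphaZero-style target: unvisited actions get exactly zero mass
visit_target = visits / visits.sum()
# Gumbel MuZero target: every action keeps mass proportional to prior + Q
gumbel_target = softmax(prior_logits + q_values)

assert visit_target[3] == 0.0 and visit_target[4] == 0.0
assert (gumbel_target > 0).all()
assert int(np.argmax(gumbel_target)) == 0  # still concentrates on the best action
```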
What makes 3D bin packing particularly interesting from a reinforcement learning perspective is that it's a perfect-information, deterministic, single-agent planning task. Once you sample a problem instance (a set of items to pack), the environment dynamics are completely known. You place an item, the container state updates deterministically, and you get a reward proportional to the volume you just filled. This structure turns out to be ideal for tree search methods, something we exploit heavily.
Let's formally specify the MDP we're solving. BinPack is an episodic MDP with deterministic dynamics, a large discrete action space, and rewards that are exactly computable from the environment state:

\[ \mathcal{M} = (\mathcal{S}, \mathcal{A}, f, r), \qquad s_{t+1} = f(s_t, a_t), \qquad r(s_t, a_t) = \frac{\text{volume(item placed at } t)}{\text{volume(container)}} \]

where \(\mathcal{S}\) contains the container and item configurations, \(\mathcal{A}\) is the set of discrete (EMS, item) placements, and \(f\) is the deterministic transition function.
The crucial observation here is that we have access to the exact transition function \(f\). Unlike Atari games where we'd need to learn a world model, or robotics where dynamics are noisy and partially observable, bin packing gives us perfect simulation for free. This means MCTS can roll forward with ground-truth dynamics, and the neural network's job is simply to imitate the search-improved policy. That asymmetry—cheap perfect simulation combined with expensive search—is exactly where search-based policy iteration shines.
We use InstaDeep's Jumanji library, which provides a JAX-native BinPack environment. The key abstraction that makes this environment tractable is the concept of Empty Maximal Spaces (EMS)—rectangular regions inside the container where items can potentially be placed.
At each step, the agent makes a joint decision: which EMS to place an item in (from up to 40 candidates), and which item to place (from up to 20 items). This gives a joint action space of 40 × 20 = 800 discrete actions. The environment uses dense rewards, meaning each placement adds the item's volume (normalized by container volume) to the cumulative return. A perfect packing that uses all available space yields a return of 1.0.
# Environment setup
import jumanji
env = jumanji.make("BinPack-v2")
# Action space dimensions
obs_num_ems = 40 # Observable empty maximal spaces
max_num_items = 20 # Maximum items per episode
num_actions = obs_num_ems * max_num_items # 800 total
# We flatten the action for simpler MCTS handling
def unflatten_action(action):
"""Flat index → (ems_id, item_id)"""
ems_id = action // max_num_items
item_id = action % max_num_items
return jnp.stack([ems_id, item_id], axis=-1)
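The inverse mapping is just ems_id * max_num_items + item_id; a quick plain-numpy round-trip check (pair_to_flat and flat_to_pair are illustrative helpers using the same arithmetic as unflatten_action):

```python
import numpy as np

max_num_items = 20

def pair_to_flat(ems_id, item_id):
    """(ems_id, item_id) -> flat index; inverse of the unflattening."""
    return ems_id * max_num_items + item_id

def flat_to_pair(action):
    """Flat index -> (ems_id, item_id)."""
    return action // max_num_items, action % max_num_items

# Round trip covers the whole 800-action space
flat = np.arange(40 * max_num_items)
ems_id, item_id = flat_to_pair(flat)
assert (pair_to_flat(ems_id, item_id) == flat).all()
assert ems_id.max() == 39 and item_id.max() == 19
```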
The environment's observation structure is designed to be "planning-friendly"—rather than giving raw pixels or low-level state, it provides two sets of tokens (EMS and items) plus boolean masks defining what's valid. This aligns nicely with transformer-style
encoders and with our flattened (E×I) action head. Each EMS is represented by its 6D bounding box coordinates (x1, x2, y1, y2, z1, z2), and each item by its three dimensions (x_len, y_len, z_len). The critical piece is the
action_mask: a boolean matrix where entry (e, i) is True if item i can legally be placed in EMS e.
How does the environment compute this joint action mask? For each EMS and each item, it tests whether the item fits in that EMS and whether the item is still available. This is implemented with nested jax.vmap,
so the entire mask computation happens in parallel:
# From jumanji/environments/packing/bin_pack/env.py
def is_action_allowed(ems, ems_mask, item, item_mask, item_placed):
item_fits_in_ems = item_fits_in_item(item, item_from_space(ems))
return ~item_placed & item_mask & ems_mask & item_fits_in_ems
action_mask = jax.vmap(
jax.vmap(is_action_allowed, in_axes=(None, None, 0, 0, 0)),
in_axes=(0, 0, None, None, None),
)(obs_ems, obs_ems_mask, items, items_mask, items_placed)
One subtle detail: the full environment can track more than 40 EMS internally, but the observation only returns the obs_num_ems largest by volume. This keeps the observation size fixed, turning a variable-structure
geometric process into a fixed-shape tensor problem—exactly what we need for efficient JAX/XLA compilation. The environment is also strict about invalid actions: if you choose an action where the item doesn't fit or has already been
placed, the episode terminates immediately. This makes action masking a first-class concern, both in PPO and especially in MCTS where we want zero probability mass on invalid moves.
The EMS representation is a classic way to represent free space as a non-disjoint set of maximal empty cuboids. If you're interested in the "pre-RL" perspective on this representation, Parreño et al.'s technical report on GRASP heuristics and Zhao et al.'s comparative review of 3D container loading provide excellent background. The key idea is that EMS makes the problem look like "choose a space + choose an item", which is exactly how our policy and value networks will model it.
The historical challenge with search-based policy iteration is computational cost. Running 32 MCTS simulations per action adds significant overhead compared to a single network forward pass. But JAX's parallelization primitives change the calculus dramatically.
The key primitives are jax.vmap for vectorizing over the batch dimension and jax.pmap for distributing across devices. With vmap, a
single line of code turns a function that processes one environment into a function that processes thousands in parallel:
# Without vmap: process episodes one at a time
for i in range(batch_size):
state, timestep = env.step(states[i], actions[i])
# With vmap: all episodes processed in parallel
state, timestep = jax.vmap(env.step)(states, actions)
The pmap primitive replicates computation across all available GPUs or TPUs. Each device processes a shard of the batch independently, and gradients are synchronized using jax.lax.pmean:
from functools import partial
import jax
import optax

@partial(jax.pmap, axis_name="devices")
def train_step(model, opt_state, batch):
# Compute gradients on this device's shard
grads = jax.grad(loss_fn)(model, batch)
# Synchronize gradients across all devices
grads = jax.lax.pmean(grads, axis_name="devices")
# Apply optimizer update
updates, opt_state = optimizer.update(grads, opt_state)
model = optax.apply_updates(model, updates)
return model, opt_state
Together, vmap and pmap give us massive throughput. On 4 GPUs with batch size 1024, we process roughly 20,000 MCTS-guided decisions per second—competitive with PPO's throughput despite running 32 simulations per decision. The JAX documentation on parallelism provides an excellent introduction to these concepts.
We adapt Jumanji's A2C architecture for our policy-value network. The key idea is using cross-attention between EMS tokens and item tokens. This lets the network reason about which items fit in which spaces, with the action mask gating the attention to only consider valid placement combinations.
The architecture processes EMS and item tokens through alternating layers of self-attention (within each token type) and cross-attention (between types). After encoding, the policy head computes logits via a bilinear form—logits[e,i] = ems_h[e] · items_h[i]—which
naturally produces the (E × I) shaped output we need. The value head pools the embeddings and predicts expected utilization in [0, 1].
# Architecture overview
#
# ┌─────────────┐ ┌─────────────┐
# │ EMS tokens │ │ Item tokens │
# │ (40 × 6D) │ │ (20 × 3D) │
# └──────┬──────┘ └──────┬──────┘
# │ │
# ▼ ▼
# ┌─────────────┐ ┌─────────────┐
# │ Self-Attn │ │ Self-Attn │
# └──────┬──────┘ └──────┬──────┘
# │ │
# └────────┬──────────┘
# ▼
# ┌───────────────┐
# │ Cross-Attn │ <- EMS ↔ Items interaction
# │ (bidirectional)│ (gated by action_mask)
# └───────┬───────┘
# │
# ┌───────┴───────┐
# ▼ ▼
# ┌─────────────┐ ┌─────────────┐
# │ Policy Head │ │ Value Head │
# │ (E×I logits)│ │ (scalar) │
# └─────────────┘ └─────────────┘
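The bilinear policy head described above can be sketched in a few lines (plain numpy with illustrative shapes; the real network computes this inside Haiku, with masking as in apply_action_mask):

```python
import numpy as np

def bilinear_policy_logits(ems_h, items_h, action_mask):
    """logits[e, i] = ems_h[e] . items_h[i], flattened to (E*I,).

    ems_h:       (E, D) encoded EMS tokens
    items_h:     (I, D) encoded item tokens
    action_mask: (E, I) True where placing item i in EMS e is legal
    """
    logits = np.einsum("ed,id->ei", ems_h, items_h)    # bilinear form
    logits = np.where(action_mask, logits, np.finfo(logits.dtype).min)
    return logits.reshape(-1)                           # matches the flat action space

E, I, D = 40, 20, 32
rng = np.random.default_rng(0)
ems_h = rng.normal(size=(E, D))
items_h = rng.normal(size=(I, D))
mask = rng.random((E, I)) > 0.5

logits = bilinear_policy_logits(ems_h, items_h, mask)
assert logits.shape == (800,)
# After softmax, invalid placements carry zero probability
probs = np.exp(logits - logits.max())
probs = probs / probs.sum()
assert probs[~mask.reshape(-1)].max() == 0.0
```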
We use Haiku to build the network. Haiku's hk.transform_with_state pattern cleanly separates the stateful module definition from the pure
functional interface that JAX requires:
import haiku as hk
def forward_fn(observation, is_eval=False):
net = BinPackPolicyValueNet(
num_transformer_layers=4,
transformer_num_heads=4,
transformer_key_size=32,
transformer_mlp_units=(256, 256),
)
return net(observation, is_training=not is_eval)
# Transform to pure functions
forward = hk.without_apply_rng(hk.transform_with_state(forward_fn))
# Initialize parameters
params, state = forward.init(rng_key, dummy_obs)
# Apply is now a pure function
(logits, value), new_state = forward.apply(params, state, observation)
We trained both search distillation and PPO on BinPack-v2 for 800 iterations across 5 random seeds. To ensure a fair comparison, both methods use the same network architecture, batch sizes, and evaluation protocol. The key hyperparameters are:
| Hyperparameter | Search Distillation | PPO |
|---|---|---|
| Rollout batch size | 1024 | 1024 |
| Training batch size | 4096 | 4096 |
| MCTS simulations | 32 | — |
| PPO epochs per iteration | — | 4 |
| Learning rate | 1e-3 | 3e-4 |
| Transformer layers | 4 | 4 |
The results are clear: search distillation achieves 96% volume utilization compared to PPO's 90%, a 6 percentage point improvement that translates to better packing efficiency in practice. Despite running 32 MCTS simulations per action, JAX parallelism keeps training time competitive—both methods reach convergence in roughly the same wall-clock time.
A few implementation details are worth highlighting, as they address common pitfalls when working with JAX and MCTS.
Python loops inside JIT-compiled functions are unrolled at compile time, causing slow compilation and memory issues for long episodes. The solution is jax.lax.scan, which compiles the loop body once
and executes it repeatedly:
# Inefficient: Python loop gets unrolled
def rollout_bad(state, keys):
trajectory = []
for key in keys:
state, data = step(state, key)
trajectory.append(data)
return trajectory
# Efficient: scan compiles once, runs fast
def rollout_good(state, keys):
def step_fn(carry, key):
state = carry
state, data = step(state, key)
return state, data
final_state, trajectory = jax.lax.scan(step_fn, state, keys)
return trajectory
Invalid actions must be masked before computing the softmax for action selection. A common mistake is using -inf for invalid logits, which causes NaN when all actions are invalid (as can happen at episode
termination). The solution is to use a large but finite negative value:
def apply_action_mask(logits, valid_mask):
"""Set invalid action logits to a large negative value."""
# Center for numerical stability
logits = logits - jnp.max(logits, axis=-1, keepdims=True)
# Use finite minimum (not -inf) to avoid NaN
return jnp.where(valid_mask, logits, jnp.finfo(logits.dtype).min)
The mctx library uses the same approach for the same reason. At the end of an episode, all actions can become invalid, and the code needs to handle this gracefully.
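The failure mode is easy to reproduce with a plain-numpy version of the same masking logic (masked_softmax is an illustrative helper):

```python
import numpy as np

def masked_softmax(logits, valid_mask, use_inf=False):
    """Softmax with invalid actions masked out before normalization."""
    fill = -np.inf if use_inf else np.finfo(logits.dtype).min
    masked = np.where(valid_mask, logits - logits.max(), fill)
    e = np.exp(masked - masked.max())
    return e / e.sum()

logits = np.zeros(4)
no_valid = np.zeros(4, dtype=bool)  # e.g. a terminal state: nothing is legal

# With -inf, the whole distribution becomes NaN once all actions are invalid
assert np.isnan(masked_softmax(logits, no_valid, use_inf=True)).all()
# With a finite minimum, we get a well-defined (uniform) fallback instead
assert np.allclose(masked_softmax(logits, no_valid), 0.25)
```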
Search-based policy iteration is particularly well-suited to problems where you have exact dynamics (so MCTS can simulate accurately), episodes are relatively short (so MCTS overhead doesn't compound excessively), actions have complex dependencies that gradient signals might miss, and you have parallel compute to amortize the cost of search. Bin packing hits all four criteria.
PPO remains preferable when dynamics must be learned rather than simulated, horizons are very long (hundreds of steps), action spaces are simple with clear reward gradients, or when you need the simplicity of a model-free approach.
The complete implementation is available on GitHub. Key dependencies include mctx for MCTS, Jumanji for the BinPack environment, Haiku for neural networks, and Optax for optimization. The PGX AlphaZero example was a helpful reference for JAX-native AlphaZero patterns.
@misc{muppidi2026searchdistill,
title={Policy Iteration via Search Distillation for 3D Bin Packing},
author={Muppidi, Aneesh},
year={2026},
url={https://github.com/Aneeshers/policy_search_distillation}
}