Technical Report · 2025

AlphaZero for 3D Bin Packing: MCTS with Learned Policy-Value Networks

Aneesh Muppidi1
1Harvard University
Can we combine Monte Carlo Tree Search with neural network guidance to solve combinatorial packing problems that defeat standard RL approaches?
TL;DR

We apply AlphaZero-style MCTS to 3D bin packing using JAX, the mctx library, and a custom Transformer architecture. Our approach achieves 0.96 average utilization compared to 0.90 from PPO—a 6 percentage point improvement at identical network capacity, demonstrating the power of search-based planning for combinatorial optimization.

The 3D Bin Packing Problem

The 3D bin packing problem is a classic NP-hard combinatorial optimization challenge: given a set of rectangular items with dimensions (x, y, z) and a container of fixed size, pack as many items as possible to maximize space utilization. Unlike the 1D knapsack where we only decide which items to include, 3D packing requires deciding where to place each item—and placement affects future available spaces.

Jumanji's BinPack-v2 environment formulates this as a sequential decision process. At each timestep, the agent selects:

  1. An Empty Maximal Space (EMS)—a maximal empty rectangular region in the container
  2. An item from the remaining unpacked items

The action space is the Cartesian product A = E × I of EMS indices and item indices, with |E| = obs_num_ems = 40 and |I| = max_num_items = 20. This yields 800 possible actions per step.
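MCTS operates on a single flat action index, so the (ems_id, item_id) pair is flattened in row-major order to match the reshaped (E, I) action mask used later. A minimal sketch, assuming I = 20 and the unflatten_action helper referenced further below:

import jax.numpy as jnp

NUM_ITEMS = 20  # I = max_num_items

def flatten_action(ems_id, item_id):
    # (ems_id, item_id) -> flat index in [0, E * I), row-major over the (E, I) mask
    return ems_id * NUM_ITEMS + item_id

def unflatten_action(action):
    # Flat index of shape (B,) -> (B, 2) array of (ems_id, item_id)
    return jnp.stack([action // NUM_ITEMS, action % NUM_ITEMS], axis=-1)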

State Representation

Observation Space
Field         Shape   Description
ems           (E, 6)  EMS coordinates (x1, x2, y1, y2, z1, z2)
ems_mask      (E,)    Valid EMS slots
items         (I, 3)  Item dimensions (x_len, y_len, z_len)
items_mask    (I,)    Valid item slots
items_placed  (I,)    Already-placed items
action_mask   (E, I)  Feasible (EMS, item) pairs

The dense reward equals the volume utilization gained at each step. Episode returns sum to total utilization in [0, 1].
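As a concrete sketch (the helper below is illustrative, not the environment's actual code; dimensions are (x_len, y_len, z_len) tuples):

def dense_reward(item_dims, container_dims):
    # Fraction of container volume occupied by the newly placed item
    item_volume = item_dims[0] * item_dims[1] * item_dims[2]
    container_volume = container_dims[0] * container_dims[1] * container_dims[2]
    return item_volume / container_volume  # per-step rewards sum to total utilization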

Monte Carlo Tree Search with Neural Guidance

Pure MCTS struggles with large action spaces. AlphaZero's insight: use a neural network to provide prior probabilities over actions (guiding exploration) and value estimates (replacing random rollouts).

AlphaZero Search (per move)
for sim = 1 to N do
    Selection: Traverse tree via UCB scores
    Expansion: At leaf, evaluate neural net → (π, v)
    Backup: Propagate v up the tree
end for
return action_weights ∝ visit_counts

The neural network is trained to match the search policy (visit distribution) and value (Monte Carlo returns), creating a virtuous cycle: better networks → better search → better training targets → better networks.

UCB(s, a) = Q(s, a) + c · π(a|s) · √N(s) / (1 + N(s, a))
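As a small illustrative function (the exploration constant c = 1.25 below is a common default, not a value taken from this project):

import jax.numpy as jnp

def puct_score(q, prior, parent_visits, child_visits, c=1.25):
    # Q(s, a) + c * pi(a|s) * sqrt(N(s)) / (1 + N(s, a)), vectorized over actions
    return q + c * prior * jnp.sqrt(parent_visits) / (1.0 + child_visits)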

The mctx Library: Gumbel MuZero

We use DeepMind's mctx library, which provides JAX-native MCTS implementations. Specifically, we employ gumbel_muzero_policy—a policy improvement algorithm from "Policy improvement by planning with Gumbel" (ICLR 2022).

Why Gumbel MuZero?

Standard MCTS (and AlphaZero) selects root actions stochastically, e.g. by sampling with added exploration noise. Gumbel MuZero instead uses Sequential Halving with Gumbel noise to deterministically select which root actions to explore, which guarantees a policy improvement even with a small simulation budget.
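A minimal illustration of the Gumbel top-k step at the root (mctx handles this internally; the function below is only for intuition):

import jax
import jax.numpy as jnp

def gumbel_top_k(rng_key, root_logits, k):
    # Sample k distinct root actions without replacement, favouring high-prior actions
    gumbel = jax.random.gumbel(rng_key, shape=root_logits.shape)
    return jnp.argsort(root_logits + gumbel, axis=-1)[..., -k:]

Sequential Halving then splits the simulation budget across these k candidates, discarding the lower-valued half at each phase.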

The Recurrent Function

MCTS requires a model of the environment. For BinPack we use the true environment dynamics (a perfect model), so although we call gumbel_muzero_policy, the search is AlphaZero-style rather than relying on MuZero's learned model:

def recurrent_fn(model, rng_key, action, state):
    model_params, model_state = model
    
    # Unflatten: (B,) -> (B, 2) for (ems_id, item_id)
    action_pair = unflatten_action(action)
    
    # Step the environment
    next_state, next_ts = jax.vmap(env.step)(state, action_pair)
    
    # Neural network predictions at next state
    obs = next_ts.observation
    (logits, value), _ = forward.apply(
        model_params, model_state, obs
    )
    
    # Mask invalid actions
    valid_flat = safe_action_mask_flat(obs.action_mask)
    logits = mask_logits(logits, valid_flat)
    
    return mctx.RecurrentFnOutput(
        reward=next_ts.reward,
        # Continue the search through non-terminal states; stop at terminal ones
        discount=jnp.where(next_ts.discount == 0, 0.0, 1.0),
        prior_logits=logits,
        value=value,
    ), next_state

Action Masking

The 2D action mask (E, I) must be flattened and passed to MCTS through the invalid_actions argument (which marks invalid actions, hence the negation ~valid_flat in the call below). We also ensure at least one action is valid to avoid numerical issues:

def safe_action_mask_flat(action_mask_2d):
    """Flatten (B, E, I) -> (B, E*I) with fallback."""
    flat = action_mask_2d.reshape((action_mask_2d.shape[0], -1))
    has_any = jnp.any(flat, axis=-1)
    
    # If no valid action, allow dummy action 0
    dummy = jax.nn.one_hot(
        jnp.zeros_like(has_any, dtype=jnp.int32), 
        num_actions
    ).astype(jnp.bool_)
    
    return jnp.where(has_any[:, None], flat, dummy)
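The mask_logits helper used inside recurrent_fn is the standard large-negative trick; a minimal sketch (the project's exact implementation may differ):

import jax.numpy as jnp

def mask_logits(logits, valid_flat):
    # Invalid actions receive a very negative logit so the softmax ignores them
    return jnp.where(valid_flat, logits, -1e9)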

Invoking the Search

policy_out = mctx.gumbel_muzero_policy(
    params=model,
    rng_key=key,
    root=mctx.RootFnOutput(
        prior_logits=root_logits,
        value=value,
        embedding=state,  # Env state as embedding
    ),
    recurrent_fn=recurrent_fn,
    num_simulations=32,
    invalid_actions=~valid_flat,
    qtransform=mctx.qtransform_completed_by_mix_value,
    gumbel_scale=1.0,
)

action = policy_out.action
action_weights = policy_out.action_weights  # Training target
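The flat action returned by the search is converted back to an (ems_id, item_id) pair to step the real environment; action_weights (the improved policy) and the eventual episode return become the training targets. A sketch of the acting step, reusing the names from the snippets above:

# Step the real batched environment with the searched action
env_action = unflatten_action(policy_out.action)      # (B,) -> (B, 2)
state, timestep = jax.vmap(env.step)(state, env_action)
# Store (observation, action_weights, reward, ...) for the training phase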

Transformer Architecture

The policy-value network must process variable-length sets of EMS regions and items with complex relational structure. We use a Transformer-based architecture with cross-attention between EMS and item tokens.

Torso: Cross-Attention Between EMS and Items

The key insight: EMS-item compatibility is encoded in the action mask. We use this as an attention mask for cross-attention:

class BinPackTorso(hk.Module):
    def __call__(self, observation):
        # Embed EMS: (B, E, 6) -> (B, E, D)
        ems_leaves = jnp.stack(
            jax.tree_util.tree_leaves(observation.ems), axis=-1
        )
        ems_emb = hk.Linear(self.model_size)(ems_leaves)
        
        # Embed Items: (B, I, 3) -> (B, I, D)
        item_leaves = jnp.stack(
            jax.tree_util.tree_leaves(observation.items), axis=-1
        )
        items_emb = hk.Linear(self.model_size)(item_leaves)
        
        # Self-attention masks (ems_mask, items_mask) are built from the
        # observation's validity flags; their construction is omitted here.
        # Cross-attention masks come from action_mask: (B, 1, E, I) and its
        # transpose (B, 1, I, E).
        ems_cross_items = jnp.expand_dims(
            observation.action_mask, axis=-3
        )
        items_cross_ems = jnp.expand_dims(
            jnp.swapaxes(observation.action_mask, -1, -2), axis=-3
        )
        
        for block_id in range(self.num_layers):
            # Self-attention on EMS
            ems_emb = TransformerBlock(...)(
                ems_emb, ems_emb, ems_emb, ems_mask
            )
            # Self-attention on Items
            items_emb = TransformerBlock(...)(
                items_emb, items_emb, items_emb, items_mask
            )
            # Cross-attention: EMS ↔ Items
            ems_emb = TransformerBlock(...)(
                ems_emb, items_emb, items_emb, ems_cross_items
            )
            items_emb = TransformerBlock(...)(
                items_emb, ems_emb, ems_emb, items_cross_ems
            )
        
        return ems_emb, items_emb
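The TransformerBlock internals are elided above; a minimal sketch of what such a pre-norm attention + MLP block could look like in Haiku (the project's actual block may differ in details):

import haiku as hk
import jax

class TransformerBlock(hk.Module):
    """Pre-norm multi-head attention + feed-forward, with residual connections."""

    def __init__(self, num_heads=4, key_size=32, model_size=128,
                 mlp_units=(256, 256), name=None):
        super().__init__(name=name)
        self.num_heads = num_heads
        self.key_size = key_size
        self.model_size = model_size
        self.mlp_units = mlp_units

    def __call__(self, query, key, value, mask=None):
        # Attention sub-layer; mask is broadcast over heads, e.g. (B, 1, T_q, T_kv)
        attn = hk.MultiHeadAttention(
            num_heads=self.num_heads,
            key_size=self.key_size,
            model_size=self.model_size,
            w_init=hk.initializers.VarianceScaling(2.0),
        )(
            hk.LayerNorm(-1, True, True)(query),
            hk.LayerNorm(-1, True, True)(key),
            hk.LayerNorm(-1, True, True)(value),
            mask=mask,
        )
        x = query + attn
        # Feed-forward sub-layer projecting back to model_size
        h = hk.nets.MLP(
            [*self.mlp_units, self.model_size], activation=jax.nn.gelu
        )(hk.LayerNorm(-1, True, True)(x))
        return x + h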

Policy Head: Outer Product

To produce logits over the joint action space (E × I), we take a dot product between every projected EMS embedding and every projected item embedding, i.e. an outer product over the E and I axes:

# Project to policy space
ems_h = hk.Linear(self.model_size)(ems_embeddings)     # (B, E, D)
items_h = hk.Linear(self.model_size)(items_embeddings) # (B, I, D)

# Outer product: (B, E, D) x (B, I, D) -> (B, E, I)
logits_2d = jnp.einsum("...ek,...ik->...ei", ems_h, items_h)

# Mask and flatten: (B, E, I) -> (B, E*I)
logits_2d = jnp.where(observation.action_mask, logits_2d, -1e9)
logits = logits_2d.reshape(batch_size, -1)

Value Head: Pooled Representations

# Sum-pool over valid EMS and still-available items
ems_sum = jnp.sum(ems_emb, axis=-2, where=observation.ems_mask[..., None])
items_avail = observation.items_mask & ~observation.items_placed
items_sum = jnp.sum(items_emb, axis=-2, where=items_avail[..., None])

# MLP to scalar value in [0, 1]
joint = jnp.concatenate([ems_sum, items_sum], axis=-1)
v = hk.nets.MLP([D, D, 1])(joint)
v = jax.nn.sigmoid(jnp.squeeze(v, axis=-1))  # Utilization target

Architecture Summary

4 Transformer layers · 4 attention heads · key size 32 · model dimension D = 128 · MLP hidden units (256, 256)
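A sketch of how the torso and heads can be wired into the forward used throughout this post (the BinPackTorso constructor arguments, the policy_head/value_head helpers, and dummy_observation are assumptions here):

import haiku as hk
import jax

def network_fn(observation):
    torso = BinPackTorso(num_layers=4, num_heads=4, key_size=32, model_size=128)
    ems_emb, items_emb = torso(observation)
    logits = policy_head(ems_emb, items_emb, observation)  # outer-product head above
    value = value_head(ems_emb, items_emb, observation)    # pooled MLP head above
    return logits, value

# Matches the forward.apply(params, state, obs) -> ((logits, value), state) calls above
forward = hk.without_apply_rng(hk.transform_with_state(network_fn))
params, net_state = forward.init(jax.random.PRNGKey(0), dummy_observation)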

Training: AlphaZero vs PPO

AlphaZero Training Loop

Each iteration:

  1. Self-play: Run batched episodes using MCTS (32 simulations/move)
  2. Compute targets: action_weights from MCTS, Monte Carlo returns for value
  3. Train: Minimize cross-entropy (policy) + MSE (value)

def loss_fn(params, state, batch):
    (logits, value), state = forward.apply(params, state, batch.obs)
    
    # Policy: match MCTS action distribution
    pol_loss = optax.softmax_cross_entropy(logits, batch.policy_tgt)
    
    # Value: match Monte Carlo returns
    v_loss = optax.l2_loss(value, batch.value_tgt)
    
    mask = batch.mask.astype(jnp.float32)
    pol_loss = jnp.sum(pol_loss * mask) / jnp.sum(mask)
    v_loss = jnp.sum(v_loss * mask) / jnp.sum(mask)
    
    return pol_loss + v_loss, (state, pol_loss, v_loss)
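The value targets are undiscounted Monte Carlo returns (γ = 1.0), i.e. the utilization still to be gained from each step onward. A minimal sketch of that computation, assuming (T, B) reward and done arrays from self-play:

import jax
import jax.numpy as jnp

def mc_returns(rewards, dones):
    # Undiscounted return-to-go; 'dones' zeroes the carry at episode boundaries
    def step(carry, inputs):
        r, d = inputs
        ret = r + carry * (1.0 - d)
        return ret, ret

    _, returns = jax.lax.scan(
        step, jnp.zeros_like(rewards[0]), (rewards, dones), reverse=True
    )
    return returns  # (T, B) value targets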

PPO Baseline

For comparison, we implement PPO with the same network architecture (minus MCTS):

# PPO clipped objective
ratio = jnp.exp(new_logp - old_logp)
surr1 = ratio * advantages
surr2 = jnp.clip(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
policy_loss = -jnp.minimum(surr1, surr2).mean()

# Value clipping
v_clipped = old_v + jnp.clip(v - old_v, -clip_eps, clip_eps)
v_loss = jnp.maximum((v - returns)**2, (v_clipped - returns)**2).mean()

# 'entropy': mean entropy of the masked action distribution (encourages exploration)
loss = policy_loss + 0.5 * v_loss - 0.01 * entropy
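The advantages above come from GAE (γ = 1.0, λ = 0.95 per the hyperparameter table). A standard sketch, assuming (T, B) reward/value/done arrays and a (B,) bootstrap value:

import jax
import jax.numpy as jnp

def gae(rewards, values, dones, last_value, gamma=1.0, lam=0.95):
    # TD residuals, bootstrapping from the value of the state after the rollout
    values_tp1 = jnp.concatenate([values[1:], last_value[None]], axis=0)
    deltas = rewards + gamma * (1.0 - dones) * values_tp1 - values

    def step(carry, inputs):
        delta, d = inputs
        adv = delta + gamma * lam * (1.0 - d) * carry
        return adv, adv

    _, advantages = jax.lax.scan(
        step, jnp.zeros_like(last_value), (deltas, dones), reverse=True
    )
    return advantages, advantages + values  # (advantages, returns)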

Results

Method               Avg Return   Eval (Greedy)   Sims/Move
PPO                  0.90         0.90            –
AlphaZero (nsim=8)   0.93         0.92            8
AlphaZero (nsim=32)  0.96         0.95            32
AlphaZero (nsim=64)  0.96         0.95            64
Why does MCTS help so much? Although the reward is dense, the consequences of a placement are delayed: a seemingly good early placement can block better solutions later. MCTS explicitly searches ahead to evaluate those consequences.

Key Observations

Finding 1: Search provides stronger training signal. MCTS action weights encode multi-step lookahead, while PPO's policy gradient learns from sampled rollouts with no explicit planning.

Finding 2: Diminishing returns beyond 32 simulations. The marginal gain from 32 to 64 sims is small, suggesting the bottleneck shifts to network capacity or training data diversity.

Finding 3: Greedy evaluation underperforms MCTS execution. The learned policy alone achieves ~0.95, but running MCTS at test time could push higher (see the sketch below). This gap represents the value of search at inference.
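A sketch of the greedy evaluation referenced in Finding 3 (acting with the raw network and no search; names follow the earlier snippets):

# Greedy, search-free action selection at evaluation time
(logits, _), _ = forward.apply(params, net_state, timestep.observation)
valid_flat = safe_action_mask_flat(timestep.observation.action_mask)
greedy_action = jnp.argmax(jnp.where(valid_flat, logits, -1e9), axis=-1)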

Computational Cost

MCTS is significantly more expensive per environment step: with 32 simulations per move, each action requires roughly 32 extra network evaluations and simulated environment transitions, versus a single forward pass for PPO.

However, AlphaZero reaches 0.96 utilization in ~400 iterations, while PPO plateaus at 0.90 regardless of training length.

Hyperparameters

Parameter               AlphaZero   PPO
Learning rate           1e-3        3e-4
Batch size (self-play)  1024        1024
Batch size (training)   4096        4096
Discount (γ)            1.0         1.0
GAE λ                   –           0.95
PPO clip ε              –           0.2
Entropy coef            –           0.01
Max grad norm           1.0         1.0

References

  1. Silver et al. "Mastering the game of Go with deep neural networks and tree search." Nature, 2016.
  2. Schrittwieser et al. "Mastering Atari, Go, chess and shogi by planning with a learned model." Nature, 2020.
  3. Danihelka et al. "Policy improvement by planning with Gumbel." ICLR, 2022.
  4. InstaDeep. "Jumanji: A diverse suite of scalable RL environments in JAX." 2023.
  5. DeepMind. "mctx: MCTS-in-JAX."