We apply AlphaZero-style MCTS to 3D bin packing using JAX, the mctx library, and a custom Transformer architecture. Our approach achieves 0.96 average utilization compared to 0.90 from PPO—a 6 percentage point improvement at identical network capacity, demonstrating the power of search-based planning for combinatorial optimization.
The 3D bin packing problem is a classic NP-hard combinatorial optimization challenge: given a set of rectangular items with dimensions (x, y, z) and a container of fixed size, pack as many items as possible
to maximize space utilization. Unlike the 1D knapsack where we only decide which items to include, 3D packing requires deciding where to place each item—and placement affects future available spaces.
Jumanji's BinPack-v2 environment formulates this as a sequential decision process. At each timestep, the agent selects an EMS (Empty Maximal Space, a cuboid region of remaining free space) and an item to place into it.
The action space is the Cartesian product: A = E × I, where E = obs_num_ems = 40 and I = max_num_items = 20. This yields 800 possible actions per step.
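Before MCTS sees it, the (ems_id, item_id) pair is flattened to a single index. A minimal sketch of the conversion we assume here, consistent with the row-major reshape used in `safe_action_mask_flat` later (the repo's `unflatten_action` helper is taken to follow the same convention):

import jax.numpy as jnp

NUM_EMS, NUM_ITEMS = 40, 20  # obs_num_ems, max_num_items

def flatten_action(ems_id, item_id):
    # (ems_id, item_id) -> flat index in [0, NUM_EMS * NUM_ITEMS)
    return ems_id * NUM_ITEMS + item_id

def unflatten_action(action):
    # flat index -> (ems_id, item_id); works elementwise on batched arrays
    return jnp.stack([action // NUM_ITEMS, action % NUM_ITEMS], axis=-1)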
| Field | Shape | Description |
|---|---|---|
| `ems` | (E, 6) | EMS coordinates (x1, x2, y1, y2, z1, z2) |
| `ems_mask` | (E,) | Valid EMS slots |
| `items` | (I, 3) | Item dimensions (x_len, y_len, z_len) |
| `items_mask` | (I,) | Valid item slots |
| `items_placed` | (I,) | Already-placed items |
| `action_mask` | (E, I) | Feasible (EMS, item) pairs |
The dense reward equals the volume utilization gained at each step. Episode returns sum to total utilization in [0, 1].
Pure MCTS struggles with large action spaces. AlphaZero's insight: use a neural network to provide prior probabilities over actions (guiding exploration) and value estimates (replacing random rollouts).
The neural network is trained to match the search policy (visit distribution) and value (Monte Carlo returns), creating a virtuous cycle: better networks → better search → better training targets → better networks.
We use DeepMind's mctx library, which provides JAX-native MCTS implementations. Specifically, we employ gumbel_muzero_policy—a policy improvement
algorithm from "Policy improvement by planning with Gumbel" (ICLR 2022).
Standard MCTS uses stochastic action selection at the root. Gumbel MuZero instead uses Sequential Halving with Gumbel noise to deterministically select which actions to explore, and forms an improved policy from the network prior and completed Q-values:

softmax(prior_logits + completed_qvalues)

MCTS requires a model of the environment. For BinPack, we use the true environment dynamics (a perfect model):
def recurrent_fn(model, rng_key, action, state):
    model_params, model_state = model
    # Unflatten: (B,) -> (B, 2) for (ems_id, item_id)
    action_pair = unflatten_action(action)
    # Step the environment with the true dynamics
    next_state, next_ts = jax.vmap(env.step)(state, action_pair)
    # Neural network predictions at the next state
    obs = next_ts.observation
    (logits, value), _ = forward.apply(
        model_params, model_state, obs
    )
    # Mask invalid actions
    valid_flat = safe_action_mask_flat(obs.action_mask)
    logits = mask_logits(logits, valid_flat)
    return mctx.RecurrentFnOutput(
        reward=next_ts.reward,
        # Terminal steps carry discount 0; episodes are otherwise undiscounted
        discount=jnp.where(next_ts.discount == 0, 0.0, 1.0),
        prior_logits=logits,
        value=value,
    ), next_state
View recurrent_fn
The 2D action mask (E, I) must be flattened and passed to MCTS as invalid_actions. We ensure at least one action is valid to avoid numerical issues:
def safe_action_mask_flat(action_mask_2d):
    """Flatten (B, E, I) -> (B, E*I) with fallback."""
    flat = action_mask_2d.reshape((action_mask_2d.shape[0], -1))
    num_actions = flat.shape[-1]  # E * I
    has_any = jnp.any(flat, axis=-1)
    # If no valid action, allow dummy action 0
    dummy = jax.nn.one_hot(
        jnp.zeros_like(has_any, dtype=jnp.int32),
        num_actions,
    ).astype(jnp.bool_)
    return jnp.where(has_any[:, None], flat, dummy)
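# With root network predictions (root_logits, value) and the flattened mask
# from safe_action_mask_flat, the root search is a single call to
# mctx.gumbel_muzero_policy: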
policy_out = mctx.gumbel_muzero_policy(
params=model,
rng_key=key,
root=mctx.RootFnOutput(
prior_logits=root_logits,
value=value,
embedding=state, # Env state as embedding
),
recurrent_fn=recurrent_fn,
num_simulations=32,
invalid_actions=~valid_flat,
qtransform=mctx.qtransform_completed_by_mix_value,
gumbel_scale=1.0,
)
action = policy_out.action
action_weights = policy_out.action_weights # Training target
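To build intuition for the root action selection, here is a small, self-contained illustration of the Gumbel-top-k trick: perturbing the prior logits with Gumbel noise and keeping the top k is equivalent to sampling k actions without replacement from the prior. This is purely illustrative, not the mctx internals:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
prior_logits = jnp.array([1.2, 0.3, -0.5, 2.0])  # toy priors over 4 actions

gumbel = jax.random.gumbel(key, prior_logits.shape)
k = 2
# The k root candidates that Sequential Halving then compares by spending
# the simulation budget on them and discarding the weaker half each round.
candidates = jnp.argsort(prior_logits + gumbel)[-k:]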
The policy-value network must process variable-length sets of EMS regions and items with complex relational structure. We use a Transformer-based architecture with cross-attention between EMS and item tokens.
The key insight: EMS-item compatibility is encoded in the action mask. We use this as an attention mask for cross-attention:
class BinPackTorso(hk.Module):
    def __call__(self, observation):
        # Embed EMS: (B, E, 6) -> (B, E, D)
        ems_leaves = jnp.stack(
            jax.tree_util.tree_leaves(observation.ems), axis=-1
        )
        ems_emb = hk.Linear(self.model_size)(ems_leaves)
        # Embed Items: (B, I, 3) -> (B, I, D)
        item_leaves = jnp.stack(
            jax.tree_util.tree_leaves(observation.items), axis=-1
        )
        items_emb = hk.Linear(self.model_size)(item_leaves)
        # Self-attention masks from the validity flags: (B, 1, E, E) and
        # (B, 1, I, I). Construction assumed; see the full source via the
        # link below.
        ems_mask = (
            observation.ems_mask[..., None, :, None]
            & observation.ems_mask[..., None, None, :]
        )
        items_avail = observation.items_mask & ~observation.items_placed
        items_mask = (
            items_avail[..., None, :, None] & items_avail[..., None, None, :]
        )
        # Cross-attention masks from action_mask: (B, 1, E, I) and (B, 1, I, E)
        ems_cross_items = jnp.expand_dims(
            observation.action_mask, axis=-3
        )
        items_cross_ems = jnp.expand_dims(
            jnp.swapaxes(observation.action_mask, -1, -2), axis=-3
        )
        for block_id in range(self.num_layers):
            # Self-attention on EMS
            ems_emb = TransformerBlock(...)(
                ems_emb, ems_emb, ems_emb, ems_mask
            )
            # Self-attention on Items
            items_emb = TransformerBlock(...)(
                items_emb, items_emb, items_emb, items_mask
            )
            # Cross-attention: EMS ↔ Items
            ems_emb = TransformerBlock(...)(
                ems_emb, items_emb, items_emb, ems_cross_items
            )
            items_emb = TransformerBlock(...)(
                items_emb, ems_emb, ems_emb, items_cross_ems
            )
        return ems_emb, items_emb
View BinPackTorso
To produce logits over the joint action space (E × I), we compute an outer product between projected EMS and Item embeddings:
# Project to policy space
ems_h = hk.Linear(self.model_size)(ems_embeddings) # (B, E, D)
items_h = hk.Linear(self.model_size)(items_embeddings) # (B, I, D)
# Outer product: (B, E, D) x (B, I, D) -> (B, E, I)
logits_2d = jnp.einsum("...ek,...ik->...ei", ems_h, items_h)
# Mask and flatten: (B, E, I) -> (B, E*I)
logits_2d = jnp.where(observation.action_mask, logits_2d, -1e9)
logits = logits_2d.reshape(logits_2d.shape[0], -1)
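# ----- Value head -----
# D below is the model size; ems_mask is the (B, E) validity flag from the
# observation. The pooled EMS/item embeddings go through an MLP and a sigmoid
# so the predicted value lies in [0, 1], matching the utilization-based return.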
# Sum-pool over valid EMS and available items
ems_sum = jnp.sum(ems_emb, axis=-2, where=ems_mask[..., None])
items_avail = observation.items_mask & ~observation.items_placed
items_sum = jnp.sum(items_emb, axis=-2, where=items_avail[..., None])
# MLP to scalar value in [0, 1]
joint = jnp.concatenate([ems_sum, items_sum], axis=-1)
v = hk.nets.MLP([D, D, 1])(joint)
v = jax.nn.sigmoid(jnp.squeeze(v, axis=-1)) # Utilization target
Each iteration collects self-play trajectories with MCTS, then trains the network on the resulting targets: `action_weights` from MCTS for the policy and Monte Carlo returns for the value.

def loss_fn(params, state, batch):
    (logits, value), state = forward.apply(params, state, batch.obs)
    # Policy: match MCTS action distribution
    pol_loss = optax.softmax_cross_entropy(logits, batch.policy_tgt)
    # Value: match Monte Carlo returns
    v_loss = optax.l2_loss(value, batch.value_tgt)
    # Average only over valid (unmasked) timesteps
    mask = batch.mask.astype(jnp.float32)
    pol_loss = jnp.sum(pol_loss * mask) / jnp.sum(mask)
    v_loss = jnp.sum(v_loss * mask) / jnp.sum(mask)
    return pol_loss + v_loss, (state, pol_loss, v_loss)
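The parameter update itself is standard. A minimal sketch assuming Adam with global-norm clipping, matching the learning rate and max grad norm in the hyperparameter table below (the exact optimizer setup in the repo may differ):

import jax
import optax

optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),  # max grad norm from the table
    optax.adam(1e-3),                # AlphaZero learning rate from the table
)

def update(params, state, opt_state, batch):
    grads, (state, pol_loss, v_loss) = jax.grad(loss_fn, has_aux=True)(
        params, state, batch
    )
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, state, opt_state, (pol_loss, v_loss)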
For comparison, we implement PPO with the same network architecture (minus MCTS):
# PPO clipped objective
ratio = jnp.exp(new_logp - old_logp)
surr1 = ratio * advantages
surr2 = jnp.clip(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
policy_loss = -jnp.minimum(surr1, surr2).mean()
# Value clipping
v_clipped = old_v + jnp.clip(v - old_v, -clip_eps, clip_eps)
v_loss = jnp.maximum((v - returns)**2, (v_clipped - returns)**2).mean()
loss = policy_loss + 0.5 * v_loss - 0.01 * entropy
View PPO implementation
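The advantages in the clipped objective come from GAE (λ = 0.95 in the hyperparameter table). A minimal sketch using a reverse scan over a time-major trajectory, assuming per-step rewards, value estimates, and discounts are available (not the exact code in the repo):

import jax
import jax.numpy as jnp

def compute_gae(rewards, values, discounts, last_value, lam=0.95):
    """GAE via a reverse scan; returns (advantages, value targets)."""
    def step(carry, xs):
        gae, next_value = carry
        reward, value, discount = xs
        delta = reward + discount * next_value - value  # TD error
        gae = delta + discount * lam * gae               # recursive GAE
        return (gae, value), gae

    (_, _), advantages = jax.lax.scan(
        step,
        (jnp.zeros_like(last_value), last_value),
        (rewards, values, discounts),
        reverse=True,
    )
    return advantages, advantages + values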
| Method | Avg Return | Eval (Greedy) | Sims/Move |
|---|---|---|---|
| PPO | 0.90 | 0.90 | — |
| AlphaZero (nsim=8) | 0.93 | 0.92 | 8 |
| AlphaZero (nsim=32) | 0.96 | 0.95 | 32 |
| AlphaZero (nsim=64) | 0.96 | 0.95 | 64 |
MCTS is significantly more expensive per environment step: each move runs num_simulations network evaluations and model steps inside the search, versus a single forward pass for PPO.
However, AlphaZero reaches 0.96 utilization in ~400 iterations, while PPO plateaus at 0.90 regardless of training length.
| Parameter | AlphaZero | PPO |
|---|---|---|
| Learning rate | 1e-3 | 3e-4 |
| Batch size (selfplay) | 1024 | 1024 |
| Batch size (training) | 4096 | 4096 |
| Discount (γ) | 1.0 | 1.0 |
| GAE λ | — | 0.95 |
| PPO clip ε | — | 0.2 |
| Entropy coef | — | 0.01 |
| Max grad norm | 1.0 | 1.0 |
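For reference, the settings above map onto two small config dicts; the field names below are illustrative, not the exact ones used in the training scripts:

alphazero_config = dict(
    learning_rate=1e-3,
    selfplay_batch_size=1024,
    training_batch_size=4096,
    discount=1.0,
    num_simulations=32,
    max_grad_norm=1.0,
)
ppo_config = dict(
    learning_rate=3e-4,
    rollout_batch_size=1024,
    training_batch_size=4096,
    discount=1.0,
    gae_lambda=0.95,
    clip_eps=0.2,
    entropy_coef=0.01,
    max_grad_norm=1.0,
)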