Broadcasting list of pytrees with each other without knowing the prefix a priori
19:59 13 Mar 2026

I want to broadcast a list of JAX pytrees against each other, under the assumption that one of the pytrees has the full structure and every other pytree is a prefix of it (but we don't know in advance which pytree that is).

The solution I found is the following:

import jax
import equinox as eqx

def broadcast_tree_pair(tree1, tree2):  # Broadcasts two pytrees, trying each one as the prefix of the other
    try:
        return tree1, jax.tree.broadcast(tree2, tree1)  # First try tree2 as a prefix of tree1
    except Exception:  # tree2 is not a prefix of tree1
        return jax.tree.broadcast(tree1, tree2), tree2  # Otherwise, treat tree1 as a prefix of tree2

@eqx.filter_jit
def broadcast_trees(*trees):
    num = len(trees)
    if num == 0:
        return None
    elif num == 1:
        return trees[0]
    else:
        trees = list(trees)
        for k in range(1, num):
            for l in range(k):
                trees[l], trees[k] = broadcast_tree_pair(trees[l], trees[k]) # Broadcast every distinct pair
        return tuple(trees)

broadcast_trees(1, (2,3), (3, (1,2)), 4, (3, (1, [2,5]))) # Example use case

This works, but it seems rather involved for such a simple task. Is there a simpler way to obtain the same result with common JAX tools?
