Broadcasting a list of pytrees against each other without knowing the prefix a priori
I want to broadcast a list of JAX pytrees against each other. It is assumed that one of the pytrees has all the others as prefixes, but we don't know in advance which one it is.
The solution I found is the following:
```python
import jax
import equinox as eqx


def broadcast_tree_pair(tree1, tree2):
    """Broadcast two pytrees against each other, trying each one as the prefix."""
    try:
        # First order: treat tree2 as a prefix of tree1.
        return tree1, jax.tree.broadcast(tree2, tree1)
    except ValueError:
        # Second order: treat tree1 as a prefix of tree2.
        return jax.tree.broadcast(tree1, tree2), tree2


@eqx.filter_jit
def broadcast_trees(*trees):
    num = len(trees)
    if num == 0:
        return None
    elif num == 1:
        return trees[0]
    else:
        trees = list(trees)
        # Broadcast every distinct pair (l, k) with l < k.
        for k in range(1, num):
            for l in range(k):
                trees[l], trees[k] = broadcast_tree_pair(trees[l], trees[k])
        return tuple(trees)


broadcast_trees(1, (2, 3), (3, (1, 2)), 4, (3, (1, [2, 5])))  # Example use case
```
This works, but it seems rather involved for something this simple. Is there a simpler way to get the same result with standard JAX tools?