Why does jax.tree.broadcast keep None from the full tree instead of broadcasting the prefix value?
07:23 25 Mar 2026

I'm a bit puzzled by the following JAX behavior (version 0.6.2).

Normally, if I broadcast a prefix tree into a structure tree, the value of the prefix is broadcast into the structure:

import jax
out_tree = jax.tree.broadcast(True, (False, False), is_leaf=lambda x: x is None) # Outputs (True, True)

However, if the structure tree contains None (marked as a leaf via is_leaf), the output keeps None at that position instead of the prefix value:

out_tree = jax.tree.broadcast(True, (None, False), is_leaf=lambda x: x is None) # Outputs (None, True)

For the purpose of broadcasting, I was expecting the structure tree to be just that, a structure, and that the values of its leaves should not impact the broadcasting. What is the purpose of that behavior?
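
To make the expectation concrete, here is a minimal pure-Python sketch of the semantics I had in mind. It is not how JAX implements anything: `broadcast_value` is a hypothetical helper, it only handles a single-leaf prefix, and it only walks tuples, lists and dicts.

```python
def broadcast_value(prefix, structure):
    """Fill every leaf position of `structure` with `prefix`.

    Hypothetical helper for illustration only, not a JAX API.
    Assumes `prefix` is a single leaf rather than a nested prefix tree.
    """
    if isinstance(structure, (tuple, list)):
        # Rebuild the same container type with each child broadcast in turn.
        return type(structure)(broadcast_value(prefix, item) for item in structure)
    if isinstance(structure, dict):
        return {key: broadcast_value(prefix, value) for key, value in structure.items()}
    # Anything else, including None, is treated as a leaf whose value is ignored.
    return prefix


print(broadcast_value(True, (False, False)))  # (True, True)
print(broadcast_value(True, (None, False)))   # (True, True)
```

Under this reading the second call would also give (True, True), which is what I expected from jax.tree.broadcast.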

python tree jax