I'm a bit puzzled by the following JAX behavior (version 0.6.2).
Normally, if I broadcast a prefix tree into a structure tree, the value of the prefix is broadcast into the structure:
import jax
out_tree = jax.tree.broadcast(True, (False, False), is_leaf=lambda x: x is None)  # out_tree == (True, True)
However, if the structure tree contains None (which I mark as a leaf via is_leaf), the None entries are kept in the output instead of being replaced by the prefix value:
out_tree = jax.tree.broadcast(True, (None, False), is_leaf=lambda x: x is None)  # out_tree == (None, True)
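One thing that may be relevant: by default JAX flattens None as a pytree node with no children rather than as a leaf, and is_leaf only changes that where it is actually applied. The sketch below checks how the structure tree is flattened with and without is_leaf. It is only a diagnostic; whether jax.tree.broadcast applies is_leaf when it reads the structure tree is something I have not verified. The names `none_is_leaf` and `full_tree` are introduced here for illustration.

```python
import jax

# Predicate used in the calls above: treat None as a leaf.
none_is_leaf = lambda x: x is None

# The structure tree from the second example.
full_tree = (None, False)

# Default flattening: None is an empty subtree and contributes no leaves.
default_leaves = jax.tree.leaves(full_tree)

# With is_leaf, None is reported as a leaf in its own right.
marked_leaves = jax.tree.leaves(full_tree, is_leaf=none_is_leaf)

print(default_leaves)  # [False]
print(marked_leaves)   # [None, False]
```

If the structure tree is flattened without is_leaf at some point, it would have only one leaf slot to fill, which would be consistent with the (None, True) result.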
For the purpose of broadcasting, I expected the structure tree to act purely as a structure, so that the values of its leaves would not affect the result. What is the purpose of this behavior?