
Add JAX-like vmap operator #3397

@Friedrich-S


Feature description

JAX has the very powerful vmap feature (among others). Roughly summarized, vmap takes a function written for a single example and lifts it, together with all the arrays/tensors and operations it uses, to a higher rank, so that it operates over an additional batch dimension. PyTorch, whose semantics are closer to Burn's than JAX's are, has also been (slowly) integrating this feature, first through functorch and now as `torch.func.vmap`.

Here is a brief Python example for those not familiar with JAX:

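```python
import jax
import jax.numpy as jnp
import equinox as eqx  # assumed: `nn` in the original sketch refers to Equinox
d_in, d_out = 8, 4
# Accepts tensors of shape (d_in,) and outputs (d_out,)
model = eqx.nn.Linear(d_in, d_out, key=jax.random.PRNGKey(0))
# Accepts batch tensors of shape (16, d_in) and outputs (16, d_out)
batch_model = jax.vmap(model, axis_size=16)
print(batch_model(jnp.ones((16, d_in))).shape)  # (16, 4)
```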

As you can see, this adds batching support to a simple linear layer (useful, for example, in a mixture-of-experts model) without changing any existing code. And of course, you could wrap it in another vmap to support two batch dimensions (maybe for images? Not immediately useful, but possible).
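To make these semantics concrete without any dependencies, here is a toy, loop-based sketch in plain Python. `toy_vmap`, `linear`, and the weights are hypothetical illustrations, not JAX's API: the real `jax.vmap` transforms the operations themselves rather than looping in Python. The sketch only shows the contract: a layer written for one example is lifted to one batch dimension, and nesting the wrapper lifts it to two.

```python
from typing import Callable, List, Sequence

Vector = List[float]
Matrix = List[List[float]]

def linear(w: Matrix, b: Vector, x: Vector) -> Vector:
    """Single-example linear layer: x has shape (d_in,), the result has shape (d_out,)."""
    return [sum(wij * xj for wij, xj in zip(row, x)) + bi for row, bi in zip(w, b)]

def toy_vmap(f: Callable) -> Callable:
    """Hypothetical loop-based stand-in for vmap: applies f to every slice
    along the leading axis of its argument and stacks the results."""
    def batched(xs: Sequence) -> list:
        return [f(x) for x in xs]
    return batched

# Hypothetical weights with d_in = 2 and d_out = 3.
w = [[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]]
b = [0.0, 0.0, 1.0]

def model(x: Vector) -> Vector:
    return linear(w, b, x)

batch_model = toy_vmap(model)            # (batch, d_in) -> (batch, d_out)
image_model = toy_vmap(toy_vmap(model))  # (rows, cols, d_in) -> (rows, cols, d_out)

print(batch_model([[1.0, 2.0], [3.0, 4.0]]))  # [[1.0, 4.0, 4.0], [3.0, 8.0, 8.0]]
```

Note that `model` itself is never modified; only the wrapper changes, which is the property the issue asks for.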

Feature motivation

This is a very useful feature and one of the main reasons I still primarily use JAX for ML work. A major use case is writing an implementation for a single example and then lifting it to a batch dimension afterwards, which makes the code much easier to reason about. There are many other use cases, such as implementing a complex per-element matrix operation and lifting it into a batched operation that would be tedious to write by hand. Performance does not have to suffer either: JAX turns a vmapped dot product into a single matrix multiplication rather than a loop of separate dot products, for example.
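The dot-product example can be illustrated with a toy sketch in plain Python. In JAX, the matrix-vector case corresponds to `jax.vmap(jnp.dot, in_axes=(0, None))`; the `toy_vmap` below is a hypothetical loop-based stand-in that only reproduces the *semantics* of `in_axes` (0 = slice along the leading axis, `None` = pass the whole argument to every call). It does not perform the fusion into a single matmul that JAX does; it just shows that mapping `dot` over rows gives a matrix-vector product, and nesting two maps gives a matrix product.

```python
from typing import Callable, List, Optional, Sequence, Tuple

def dot(u: Sequence[float], v: Sequence[float]) -> float:
    """Per-example operation: the dot product of two vectors."""
    return sum(a * b for a, b in zip(u, v))

def toy_vmap(f: Callable, in_axes: Tuple[Optional[int], ...]) -> Callable:
    """Hypothetical loop-based vmap: arguments with in_axes 0 are sliced along
    their leading axis; arguments with None are passed whole to every call."""
    def batched(*args):
        sizes = {len(a) for a, ax in zip(args, in_axes) if ax == 0}
        if len(sizes) != 1:
            raise ValueError("need at least one mapped argument, all of the same size")
        (n,) = sizes
        return [f(*(a[i] if ax == 0 else a for a, ax in zip(args, in_axes))) for i in range(n)]
    return batched

def transpose(m: List[List[float]]) -> List[List[float]]:
    return [list(col) for col in zip(*m)]

A = [[1.0, 2.0], [3.0, 4.0]]
B = [[5.0, 6.0], [7.0, 8.0]]

# Map dot over the rows of A while sharing v: a matrix-vector product.
matvec = toy_vmap(dot, in_axes=(0, None))
print(matvec(A, [1.0, 1.0]))  # [3.0, 7.0]

# Nest the maps: rows of A (outer) against columns of B (inner): a matrix product.
row_times_cols = toy_vmap(dot, in_axes=(None, 0))
matmul = toy_vmap(row_times_cols, in_axes=(0, None))
print(matmul(A, transpose(B)))  # [[19.0, 22.0], [43.0, 50.0]]
```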

(Optional) Suggest a Solution

I am honestly not sure, and I imagine this would be very complex to implement. Python has the "advantage" of being highly dynamic and easily supporting reflection and other runtime trickery. In Rust, I imagine this would need a lot of macro magic, if it is possible at all.
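For context on the mechanism, JAX's vmap does not inspect the user's function by reflection; it traces the function with special batched values and relies on each primitive operation having a batching rule. The toy Python sketch below illustrates that idea: `Batched`, `with_batching_rule`, and `toy_vmap` are hypothetical names, the only rule is a generic "loop over the batch" rule, and real systems give primitives such as dot specialized rules. The point it shows is that `f` works batched purely because its primitives dispatch on batched operands. Whether this maps onto a Rust tensor wrapper or backend decorator (in the spirit of Burn's `Autodiff` backend) is an untested assumption.

```python
from dataclasses import dataclass
from typing import Callable, List, Union

@dataclass
class Batched:
    """Toy tracer: the per-example values of one batch, stacked along a leading axis."""
    items: List[float]

Operand = Union[float, Batched]

def with_batching_rule(scalar_op: Callable[[float, float], float]) -> Callable[[Operand, Operand], Operand]:
    """Lift a per-example binary primitive so it also accepts Batched operands.

    This is the generic "loop over the batch" rule; a real system would give
    primitives such as dot a specialized rule (e.g. emit one matmul instead).
    """
    def op(a: Operand, b: Operand) -> Operand:
        if not isinstance(a, Batched) and not isinstance(b, Batched):
            return scalar_op(a, b)  # plain per-example call
        sizes = {len(x.items) for x in (a, b) if isinstance(x, Batched)}
        if len(sizes) != 1:
            raise ValueError("batched operands must have the same batch size")
        (n,) = sizes
        # Unbatched operands are broadcast to every example in the batch.
        xs = a.items if isinstance(a, Batched) else [a] * n
        ys = b.items if isinstance(b, Batched) else [b] * n
        return Batched([scalar_op(x, y) for x, y in zip(xs, ys)])
    return op

add = with_batching_rule(lambda x, y: x + y)
mul = with_batching_rule(lambda x, y: x * y)

def toy_vmap(f: Callable[..., Operand]) -> Callable[..., List[float]]:
    """Run f once on Batched inputs; each primitive inside f dispatches on them."""
    def batched_f(*columns: List[float]) -> List[float]:
        out = f(*(Batched(list(c)) for c in columns))
        return out.items  # assumes the output depends on a batched input
    return batched_f

def f(x: Operand, y: Operand) -> Operand:
    # Written for a single example, using only the primitives above.
    return add(mul(x, x), y)

print(f(2.0, 1.0))                                       # 5.0
print(toy_vmap(f)([1.0, 2.0, 3.0], [10.0, 20.0, 30.0]))  # [11.0, 24.0, 39.0]
```

The user function `f` is called exactly once in the batched case; all batching logic lives in the primitives, which is what would make such a design feasible without runtime reflection.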

It is completely understandable if this is closed as not planned.
