Implement `Reactant.batch` function for better batching (vmap too!) #180
Currently, batching is implemented by replacing broadcasting with the `enzyme.batch` op and tracing over the broadcasted code. The following example should just work:

One inconvenience of Julia's broadcasting is that there is no way to specify the dimension over which to broadcast; it just iterates over everything. Thus, users need to use `eachslice` to slice over the desired dimension. I'm not sure whether we would then correctly batch on the desired dimension in this case, or whether it would create some extra instructions... I need to check it.
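To make the concern concrete, here is a small NumPy analogue (Python rather than Julia; the function `f` and the data are made up for illustration). Broadcasting `f` over `eachslice(x; dims=2)` amounts to calling `f` once per column. Ideally, a batching pass would lower that per-slice loop to a single vectorized reduction along the sliced dimension, which is what `fused` computes here. Whether Reactant actually achieves this is exactly the open question above.

```python
import numpy as np

def f(v):
    # Per-slice computation: sum of squares of a 1-D vector.
    return np.sum(v * v)

x = np.arange(6.0).reshape(2, 3)  # 2 rows, 3 columns

# Analogue of broadcasting f over eachslice(x; dims=2):
# f is applied once per column, i.e. an explicit loop over slices.
per_slice = np.array([f(x[:, j]) for j in range(x.shape[1])])

# What an ideal batching pass would produce: one vectorized op over the
# non-sliced axis, with no per-slice loop and no extra slicing instructions.
fused = np.sum(x * x, axis=0)

print(per_slice)
print(fused)
```

Both arrays hold the same values; the difference is only in how many operations a traced program would contain.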
But it could be beneficial to have similar functionality in a single `batch` function, which would be easier to trace correctly and would feel more familiar to users coming from JAX. An example:
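As a rough illustration of the semantics such a function could have, here is a naive reference sketch in Python/NumPy. It is not Reactant code: the name `batch_reference` and its `in_axes`/`out_axis` parameters are hypothetical, loosely modeled on JAX's `vmap` (`in_axes=None` means "pass this argument unchanged to every call"). It simply loops over slices and stacks the results, whereas a real implementation would presumably trace `f` once and emit a single batched op (e.g. via `enzyme.batch`).

```python
import numpy as np

def batch_reference(f, in_axes, out_axis=0):
    """Naive reference semantics for a hypothetical batch(f; in_axes, out_axis).

    in_axes[i] is the axis of argument i to map over, or None to pass that
    argument unchanged to every call (like JAX's vmap in_axes=None).
    """
    def batched(*args):
        if len(args) != len(in_axes):
            raise ValueError("need one in_axes entry per argument")
        sizes = {a.shape[ax] for a, ax in zip(args, in_axes) if ax is not None}
        if len(sizes) != 1:
            raise ValueError("mapped axes must exist and all have the same length")
        n = sizes.pop()
        results = []
        for i in range(n):
            # Take slice i along each mapped axis; unmapped args pass through.
            call_args = [a if ax is None else np.take(a, i, axis=ax)
                         for a, ax in zip(args, in_axes)]
            results.append(f(*call_args))
        # Stack the per-slice results along the requested output axis.
        return np.stack(results, axis=out_axis)
    return batched

# Per-example function: matrix-vector product W @ v.
def matvec(W, v):
    return W @ v

W = np.arange(6.0).reshape(2, 3)   # shared weights, shape (2, 3)
X = np.arange(12.0).reshape(3, 4)  # 4 examples stored as columns, shape (3, 4)

# Map over the columns of X (axis 1), reuse W for every call,
# and stack the outputs as columns again.
Y = batch_reference(matvec, in_axes=(None, 1), out_axis=1)(W, X)
print(Y.shape)
```

The point of the explicit `in_axes` argument is that the batch dimension is stated up front, so no `eachslice` is needed and a tracer knows exactly which dimension to batch on.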