What are pytrees?
While reviewing #582, I got the feeling that we have quite a bit of code that re-implements the typical pytree operations as implemented in jax, optree or my instructional library pybaum.
The pytree operations solve the following problem in a general way: in math notation, we often need one-dimensional vectors, but in code we want to represent things in richer data formats (e.g. dictionaries of arrays of arbitrary dimension). A prime example is the parameters of a neural network. Pytrees are not a specific type themselves. For our purposes, any (nested) container of tensors or numbers is a pytree, and the pytree operations are defined on it.
The most important operations are:
tree_flatten: Convert a pytree into a list of leaves and an object that stores the tree structure
tree_unflatten: The inverse of tree_flatten
tree_map: Apply a function to one or several pytrees
tree_reduce: Apply a reduction over all leaves in a pytree
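To make the four operations concrete, here is a minimal pure-Python sketch. It is not the API of jax, optree or pybaum; the function bodies, the tuple-based tree definition and the `params` example are assumptions made for illustration, and only dicts, lists and tuples are treated as containers.

```python
from functools import reduce


def tree_flatten(tree):
    """Return (leaves, treedef) for nested dicts, lists and tuples."""
    if isinstance(tree, dict):
        keys = sorted(tree)  # fixed key order so that unflatten is deterministic
        children = [tree_flatten(tree[key]) for key in keys]
        kind = "dict"
    elif isinstance(tree, (list, tuple)):
        keys = None
        children = [tree_flatten(item) for item in tree]
        kind = type(tree).__name__
    else:
        # Anything that is not a known container is a leaf.
        return [tree], ("leaf", None, [])
    leaves = [leaf for child_leaves, _ in children for leaf in child_leaves]
    return leaves, (kind, keys, [treedef for _, treedef in children])


def tree_unflatten(treedef, leaves):
    """Inverse of tree_flatten: rebuild the tree from its leaves."""
    leaf_iter = iter(leaves)

    def build(node):
        kind, keys, children = node
        if kind == "leaf":
            return next(leaf_iter)
        built = [build(child) for child in children]
        if kind == "dict":
            return dict(zip(keys, built))
        return built if kind == "list" else tuple(built)

    return build(treedef)


def tree_map(func, tree, *rest):
    """Apply func leaf-wise to one or several trees with the same structure."""
    leaves, treedef = tree_flatten(tree)
    other_leaves = [tree_flatten(other)[0] for other in rest]
    return tree_unflatten(treedef, [func(*args) for args in zip(leaves, *other_leaves)])


def tree_reduce(func, tree, initializer):
    """Reduce over all leaves of a tree."""
    return reduce(func, tree_flatten(tree)[0], initializer)


# Hypothetical parameter tree, standing in for the parameters of a model.
params = {"layer1": {"weight": [1.0, 2.0], "bias": 0.5}, "layer2": (3.0, 4.0)}

leaves, treedef = tree_flatten(params)
restored = tree_unflatten(treedef, leaves)
doubled = tree_map(lambda x: 2 * x, params)
summed = tree_map(lambda x, y: x + y, params, params)
total = tree_reduce(lambda acc, x: acc + x, params, 0.0)
```

The flat `leaves` list plays the role of the one-dimensional vector from the math notation, while `treedef` carries everything needed to get back to the richer format.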
What is considered a leaf depends on the registry of containers, which users can extend. For example, if torch tensors are not registered as containers, tree_flatten converts a nested dict of tensors into a list of tensors. If torch tensors are registered, it flattens the same nested dict into a list of numbers.
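The effect of the registry can be sketched in a few lines of plain Python. Everything here is an assumption made for illustration: `Vector` is a hypothetical stand-in for a tensor type, and the registry is a plain dict from container type to a function returning its children, which is not how jax, optree or pybaum implement theirs.

```python
from dataclasses import dataclass


@dataclass
class Vector:
    """Hypothetical stand-in for a tensor type."""

    values: list


def default_registry():
    """Map each container type to a function that returns its children."""
    return {
        dict: lambda d: [d[key] for key in sorted(d)],
        list: lambda items: list(items),
        tuple: lambda items: list(items),
    }


def flatten_leaves(tree, registry):
    """Return the leaves of tree; unregistered types are treated as leaves."""
    children_of = registry.get(type(tree))
    if children_of is None:
        return [tree]
    return [
        leaf
        for child in children_of(tree)
        for leaf in flatten_leaves(child, registry)
    ]


params = {"a": Vector([1.0, 2.0]), "b": {"c": Vector([3.0])}}

registry = default_registry()
# Vector is not registered: each Vector is a leaf.
unregistered = flatten_leaves(params, registry)

# Register Vector as a container: flattening now descends into its values.
registry[Vector] = lambda vector: vector.values
registered = flatten_leaves(params, registry)
```

With the default registry the result is a list of two `Vector` objects; after registering `Vector`, the same call returns the three numbers they contain.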
Advantage of using pytree operations
Pytree operations can help us achieve framework independence. By adjusting the container registry, the operations stay the same for pytorch and jax models (even though jax models will typically have a more nested parameter representation).
Drawbacks of using pytree operations
We get an additional dependency
There is a learning curve to pytrees (but it's very much worth it!)
pydvl code that could be removed or simplified
pydvl.influence.torch.base.DictBilinearForm
pydvl.influence.torch.base.TensorDictOperator
pydvl.influence.torch.util.reshape_vector_to_tensors
pydvl.influence.torch.util.align_structure
pydvl.influence.torch.util.align_with_model
pydvl.influence.torch.util.flatten_dimensions