Use pytree operations to simplify code in influence #592

Open
janosg opened this issue Jun 5, 2024 · 0 comments

janosg (Collaborator) commented Jun 5, 2024

What are pytrees?

While reviewing #582, I got the feeling that we have quite a bit of code that re-implements the typical pytree operations as implemented in jax, optree or my instructional library pybaum.

The pytree operations solve the following problem in a general way: in math notation we often need one-dimensional vectors, but in code we want to represent things in richer data formats (e.g. dictionaries of arrays with arbitrary dimensions). Prime examples are the parameters of a neural network. A pytree is not a specific type. For our purposes, any (nested) container of tensors or numbers is a pytree, and the pytree operations are defined on it.
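To make the vector-vs-container problem concrete, here is a minimal hand-rolled sketch, assuming the parameters are a dict of flat lists of floats (the lists stand in for tensors). The hypothetical helpers `ravel` and `unravel` concatenate the entries into one vector in sorted-key order, record the layout, and invert it. jax ships a general version of this idea as `jax.flatten_util.ravel_pytree`, which returns the flat array together with an unravel function.

```python
from typing import Dict, List, Tuple

# Assumption: parameters are a dict of flat lists of floats (stand-ins for tensors).
Params = Dict[str, List[float]]
Layout = List[Tuple[str, int]]


def ravel(params: Params) -> Tuple[List[float], Layout]:
    """Hypothetical helper: concatenate all entries into one vector, remember the layout."""
    layout = [(key, len(params[key])) for key in sorted(params)]
    vector = [x for key, _ in layout for x in params[key]]
    return vector, layout


def unravel(vector: List[float], layout: Layout) -> Params:
    """Hypothetical helper: split the vector back into named pieces using the layout."""
    params, start = {}, 0
    for key, size in layout:
        params[key] = vector[start:start + size]
        start += size
    return params


params = {"weight": [1.0, 2.0, 3.0], "bias": [0.5]}
vector, layout = ravel(params)
print(vector)  # [0.5, 1.0, 2.0, 3.0]
assert unravel(vector, layout) == params
```

Code like this has to be rewritten for every new container layout, which is exactly the duplication the pytree operations avoid.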

The most important operations are:

  • tree_flatten: Convert a pytree into a list of leaves and an object that stores the tree structure
  • tree_unflatten: The inverse of tree_flatten
  • tree_map: Apply a function leaf-wise to one or several pytrees with the same structure
  • tree_reduce: Apply a reduction over all leaves in a pytree
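The four operations above can be sketched in plain Python for dicts, lists and tuples. This is a toy in the spirit of pybaum, not the jax or optree implementation: the treedef format, the sorted-key leaf order and the error handling are assumptions of this sketch, and the real libraries add an extensible registry and far better performance. The argument order loosely follows `jax.tree_util` (`tree_unflatten` takes the structure first).

```python
import operator
from functools import reduce
from typing import Any, Callable, List, Tuple

# Toy "treedef": None marks a leaf; otherwise (container type, dict keys or None, child defs).
TreeDef = Any


def tree_flatten(tree: Any) -> Tuple[List[Any], TreeDef]:
    """Return the leaves (depth-first, dict keys sorted) and the tree structure."""
    if type(tree) is dict:
        keys = sorted(tree)
        children = [tree[k] for k in keys]
    elif type(tree) in (list, tuple):
        keys, children = None, list(tree)
    else:
        return [tree], None  # anything that is not a known container is a leaf
    leaves: List[Any] = []
    child_defs = []
    for child in children:
        child_leaves, child_def = tree_flatten(child)
        leaves.extend(child_leaves)
        child_defs.append((len(child_leaves), child_def))
    return leaves, (type(tree), keys, child_defs)


def tree_unflatten(treedef: TreeDef, leaves: List[Any]) -> Any:
    """Inverse of tree_flatten: rebuild the containers around the leaves."""
    if treedef is None:
        return leaves[0]
    kind, keys, child_defs = treedef
    children, start = [], 0
    for n_leaves, child_def in child_defs:
        children.append(tree_unflatten(child_def, leaves[start:start + n_leaves]))
        start += n_leaves
    return dict(zip(keys, children)) if kind is dict else kind(children)


def tree_map(func: Callable[..., Any], tree: Any, *rest: Any) -> Any:
    """Apply func leaf-wise; every tree in `rest` must have the structure of `tree`."""
    leaves, treedef = tree_flatten(tree)
    rest_leaves = []
    for other in rest:
        other_leaves, other_def = tree_flatten(other)
        if other_def != treedef:
            raise ValueError("pytrees have different structures")
        rest_leaves.append(other_leaves)
    return tree_unflatten(treedef, [func(*xs) for xs in zip(leaves, *rest_leaves)])


def tree_reduce(func: Callable[[Any, Any], Any], tree: Any, initial: Any) -> Any:
    """Fold func over all leaves from left to right."""
    return reduce(func, tree_flatten(tree)[0], initial)


params = {"layer1": {"w": [1.0, 2.0], "b": 0.5}, "layer2": (3.0,)}
leaves, treedef = tree_flatten(params)
print(leaves)  # [0.5, 1.0, 2.0, 3.0]
assert tree_unflatten(treedef, leaves) == params
print(tree_map(lambda x: 2 * x, params))  # {'layer1': {'b': 1.0, 'w': [2.0, 4.0]}, 'layer2': (6.0,)}
print(tree_reduce(operator.add, params, 0.0))  # 6.5
```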

A full list is here

What counts as a leaf depends on the registry of container types, which users can extend. For example, if torch tensors are not registered as containers, tree_flatten converts a nested dict of tensors into a list of tensors. If torch tensors are registered, it flattens the same nested dict into a list of numbers.
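The effect of the registry can be shown with a toy, assuming a registry is just a mapping from container type to a function that returns the children. Plain lists play the role of torch tensors here; note that in jax and optree lists are registered by default, so leaving them out is purely for illustration. Real registries (e.g. `jax.tree_util.register_pytree_node` or `optree.register_pytree_node`) also store an unflatten function, which this sketch omits.

```python
from typing import Any, Callable, Dict, List

# Toy registry: container type -> function returning its children in a fixed order.
Registry = Dict[type, Callable[[Any], List[Any]]]

DEFAULT_REGISTRY: Registry = {
    dict: lambda d: [d[k] for k in sorted(d)],
    tuple: list,
}


def tree_leaves(tree: Any, registry: Registry) -> List[Any]:
    """Recurse into registered container types; everything else is a leaf."""
    to_children = registry.get(type(tree))
    if to_children is None:
        return [tree]
    return [leaf for child in to_children(tree) for leaf in tree_leaves(child, registry)]


# Plain lists stand in for torch tensors.
params = {"layer": {"w": [1.0, 2.0], "b": [0.5]}}

print(tree_leaves(params, DEFAULT_REGISTRY))  # [[0.5], [1.0, 2.0]]

# "Registering" lists turns them into containers, so their numbers become the leaves.
with_lists: Registry = {**DEFAULT_REGISTRY, list: list}
print(tree_leaves(params, with_lists))  # [0.5, 1.0, 2.0]
```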

pydvl code that could be removed or simplified

  • pydvl.influence.torch.base.DictBilinearForm
  • pydvl.influence.torch.base.TensorDictOperator
  • pydvl.influence.torch.util.reshape_vector_to_tensors
  • pydvl.influence.torch.util.align_structure
  • pydvl.influence.torch.util.align_with_model
  • pydvl.influence.torch.util.flatten_dimensions
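The listed pydvl helpers are not reproduced here, and nothing below mirrors their actual signatures. As a hypothetical illustration of the kind of logic that shrinks, an inner product and a bilinear form on (nested) parameter dicts reduce to a `tree_map` followed by a sum over leaves. `tree_vdot` and `bilinear_form` are made-up names, the operator `A` is assumed to be a matrix-free function on pytrees, and floats stand in for tensors.

```python
import operator
from functools import reduce
from typing import Any, Callable, List


def tree_leaves(tree: Any) -> List[Any]:
    """Toy flatten over dicts (sorted keys), lists and tuples."""
    if isinstance(tree, dict):
        return [leaf for k in sorted(tree) for leaf in tree_leaves(tree[k])]
    if type(tree) in (list, tuple):
        return [leaf for child in tree for leaf in tree_leaves(child)]
    return [tree]


def tree_map(func: Callable[..., Any], tree: Any, *rest: Any) -> Any:
    """Toy leaf-wise map; assumes every tree in `rest` has the structure of `tree`."""
    if isinstance(tree, dict):
        return {k: tree_map(func, tree[k], *(r[k] for r in rest)) for k in tree}
    if type(tree) in (list, tuple):
        return type(tree)(tree_map(func, *xs) for xs in zip(tree, *rest))
    return func(tree, *rest)


def tree_vdot(u: Any, v: Any) -> float:
    """Hypothetical helper: <u, v> as the sum of leaf-wise products."""
    return reduce(operator.add, tree_leaves(tree_map(operator.mul, u, v)), 0.0)


def bilinear_form(apply_a: Callable[[Any], Any], u: Any, v: Any) -> float:
    """Hypothetical helper: b(u, v) = <u, A v> for a matrix-free operator A on pytrees."""
    return tree_vdot(u, apply_a(v))


def double(tree: Any) -> Any:
    """Example operator A = 2 * identity, applied leaf-wise."""
    return tree_map(lambda x: 2.0 * x, tree)


u = {"w": [1.0, 2.0], "b": 3.0}
v = {"w": [0.5, 0.5], "b": 1.0}
print(tree_vdot(u, v))  # 4.5
print(bilinear_form(double, u, v))  # 9.0
```

Because nothing in `tree_vdot` refers to dict keys or tensor shapes, the same two lines work for any nesting of the parameters.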

Advantages of using pytree operations

  • They are very general and stay the same across projects, whereas custom implementations of similar logic are often project-specific.
  • They are well designed, and by composing multiple pytree operations one can achieve a lot.
  • They are optimized (especially in optree)
  • They can help us to achieve framework independence. By playing with the container registry, the operations stay the same for pytorch and jax models (even though jax models will typically have a more nested parameter representation)
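As a sketch of the framework-independence point, a gradient step written once with `tree_map` works unchanged for a flat dict with dotted names (the style of PyTorch's `named_parameters`) and for a nested dict (the style common in jax libraries). `sgd_step` is a hypothetical helper, `tree_map` is a toy stand-in for the library version, and floats stand in for tensors.

```python
from typing import Any, Callable


def tree_map(func: Callable[..., Any], tree: Any, *rest: Any) -> Any:
    """Toy leaf-wise map over dicts, lists and tuples; assumes identical structures."""
    if isinstance(tree, dict):
        return {k: tree_map(func, tree[k], *(r[k] for r in rest)) for k in tree}
    if type(tree) in (list, tuple):
        return type(tree)(tree_map(func, *xs) for xs in zip(tree, *rest))
    return func(tree, *rest)


def sgd_step(params: Any, grads: Any, lr: float) -> Any:
    """Hypothetical helper: one gradient step, agnostic of how params are nested."""
    return tree_map(lambda p, g: p - lr * g, params, grads)


# Flat, torch-style parameter names vs. a nested, jax-style parameter tree.
flat = {"linear.weight": [1.0, 2.0], "linear.bias": 0.5}
grads_flat = {"linear.weight": [1.0, 1.0], "linear.bias": 1.0}
nested = {"linear": {"weight": [1.0, 2.0], "bias": 0.5}}
grads_nested = {"linear": {"weight": [1.0, 1.0], "bias": 1.0}}

print(sgd_step(flat, grads_flat, lr=0.5))  # {'linear.weight': [0.5, 1.5], 'linear.bias': 0.0}
print(sgd_step(nested, grads_nested, lr=0.5))  # {'linear': {'weight': [0.5, 1.5], 'bias': 0.0}}
```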

Drawbacks of using pytree operations

  • We get an additional dependency
  • There is a learning curve to pytrees (but it's very much worth it!)

janosg mentioned this issue Jun 5, 2024