Update documentation in AbstractDataset to show proper creation of a custom pytorch dataset #255

Open
mjo22 opened this issue Aug 8, 2024 · 1 comment
Assignees
Labels
documentation Improvements or additions to documentation

Comments

@mjo22
Owner

mjo22 commented Aug 8, 2024

Most importantly, we need to update __getitem__, since we cannot return arbitrary pytrees from a torch.utils.data.Dataset.

However, this also involves thinking about how to load data without making unnecessary array copies. Taking the RelionDataset as an example, some questions are:

  1. Should we explicitly convert between JAX arrays and torch tensors in __getitem__? This would give us the control to make sure conversion is copy-free, on either GPU or CPU (see https://jax.readthedocs.io/en/latest/jax.dlpack.html).
  2. Do we need an is_cpu_array boolean in the RelionDataset to force JAX arrays to be read on the CPU? We definitely do not want to move to the GPU and back to the CPU unnecessarily.
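
The two questions above can be sketched as follows. This is a rough sketch, not the cryojax implementation. It assumes a recent JAX and PyTorch in which torch.from_dlpack accepts any object implementing __dlpack__, which JAX arrays do. The class ToyParticleDataset and its field names are hypothetical stand-ins for a RelionDataset-like class. Arrays are committed to the CPU once at construction, and __getitem__ returns a flat dict of tensors instead of a pytree.

```python
import jax
import numpy as np
import torch
from torch.utils.data import Dataset


class ToyParticleDataset(Dataset):
    """Hypothetical stand-in for a RelionDataset-like class (not cryojax API)."""

    def __init__(self, images: np.ndarray, poses: np.ndarray):
        # Commit the arrays to the CPU so that indexing never touches the GPU.
        cpu = jax.devices("cpu")[0]
        self.images = jax.device_put(images, cpu)
        self.poses = jax.device_put(poses, cpu)

    def __len__(self) -> int:
        return self.images.shape[0]

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        # Indexing produces new JAX arrays on the CPU; DLPack then hands
        # their buffers to torch without going through NumPy.
        image = self.images[index]
        pose = self.poses[index]
        return {
            "image": torch.from_dlpack(image),
            "pose": torch.from_dlpack(pose),
        }
```

Returning a dict of tensors keeps the items compatible with the default collate function of torch.utils.data.DataLoader. Whether the DLPack hand-off is actually copy-free still has to be verified separately.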
@mjo22
Owner Author

mjo22 commented Aug 9, 2024

Progress on point 2 in #257. We should test that things actually work out to be copy-free. It may also be possible to go from GPU jax to GPU torch in a copy-free way using this.
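
One possible way to test this is to compare buffer addresses on both sides of the conversion. The sketch below assumes that jax.Array.unsafe_buffer_pointer() and torch.Tensor.data_ptr() both report the start address of the underlying buffer for a single-device array. The helper name shares_buffer is hypothetical.

```python
import jax
import numpy as np
import torch


def shares_buffer(jax_array: jax.Array, tensor: torch.Tensor) -> bool:
    """Return True if both objects report the same buffer address."""
    return jax_array.unsafe_buffer_pointer() == tensor.data_ptr()


# Place a small array on the CPU and convert it through DLPack.
cpu = jax.devices("cpu")[0]
x = jax.device_put(np.arange(16, dtype=np.float32), cpu)
t = torch.from_dlpack(x)

print("same buffer:", shares_buffer(x, t))
```

The same check could be run with a GPU device in place of the CPU one to probe the GPU JAX to GPU torch path.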

@mjo22 mjo22 added the documentation Improvements or additions to documentation label Sep 11, 2024