Using torch.compile in Pyro models #2256
Comments
Hi @vitkl,

The purpose is to enable general support for Pyro scvi-tools models. It is possible that some models benefit from this more than others, but it's good to have the option. Pyro adds additional challenges to using torch.compile. Re-implementation of models in numpyro is not always practical because i) numpyro doesn't cover all the functionality, and ii) we observed in the past that JAX uses 2-4x the GPU memory for the same data size, meaning it is less practical for larger datasets where every bit of GPU memory matters.
I agree that the speed-up is expected to be largely model-dependent and that scVI is small and might be a bad proxy. Adam and Martin experimented with torch.compile, however, only in the PyTorch models. I would expect it's more straightforward to train the model/guide for one step (similar to our current load procedure):

scvi-tools/scvi/module/base/_base_module.py, line 388 (commit 4965279)
Do you suggest modifying self.module.on_load(self) to something like:

```python
self.module._model = torch.compile(self.module.model)
self.module._guide = torch.compile(self.module.guide)
```
As a proxy for the effect of compilation on cell2location, I can mention that our old theano+pymc3 implementation was 2-4 times faster for the same number of training steps. It would be great to see what happens here. A 2-4x speedup would be really nice.
I tried it out on my side and got some cryptic error messages (it was in a private repo with an unpublished model, though). My idea was to call self.train(max_steps=1) once and compile afterwards, i.e. warming up the guide by running a single train step. I'm happy to review if you have a PR.
I will try your suggestion. Do I get this right that you suggest:

```python
def train(self, ...):
    self.train(..., max_steps=1)
    self.module._model = torch.compile(self.module.model)
    self.module._guide = torch.compile(self.module.guide)
    self.train(...)
```

?
Yes, that's my understanding of how we do guide warmups for Pyro (e.g. when loading a trained model). I don't think pyro.clear_param_store() is necessary here.
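The reason a warmup step is needed can be illustrated with PyTorch's lazy modules, which mirror Pyro's behaviour of creating parameters only on first use. A minimal sketch (using `torch.nn.LazyLinear` as a stand-in for a Pyro guide, and `backend="eager"` so no compiler toolchain is required — both are illustrative choices, not scvi-tools API):

```python
import torch

# LazyLinear, like a Pyro autoguide, has no concrete parameters until
# its first forward pass, so torch.compile should wrap it only after a
# single warmup call has materialised them.
net = torch.nn.LazyLinear(4)
assert isinstance(net.weight, torch.nn.parameter.UninitializedParameter)

net(torch.randn(2, 8))  # warmup step: parameters become concrete
assert tuple(net.weight.shape) == (4, 8)

# Now compiling is safe; backend="eager" keeps the sketch dependency-free.
compiled = torch.compile(net, backend="eager")
out = compiled(torch.randn(2, 8))
assert out.shape == (2, 4)
```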
This is a good point. I will test this. Let's see what happens with cell2location.
It looks like this works:

```python
class MyModelClass(PyroSampleMixin, PyroSviTrainMixin, BaseModelClass):
    def train_compiled(self, **kwargs):
        import torch

        self.train(**kwargs, max_steps=1)
        self.module._model = torch.compile(self.module.model)
        self.module._guide = torch.compile(self.module.guide)
        self.train(**kwargs)
```

The model and guide are successfully replaced.
The PyTorch documentation (https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) suggests that the speedup depends on the model and workload. I wonder if this means that speedups only come for models that don't already have 100% GPU utilisation. Cell2location mainly uses very large full-data batches.

I also get errors if I attempt to use amortised inference (using an encoder NN as part of the guide):

```
File /nfs/team283/vk7/software/miniconda3farm5/envs/cell2loc_env_2023/lib/python3.9/site-packages/torch/fx/experimental/symbolic_shapes.py:1544, in ShapeGuardPrinter._print_Symbol(self, expr)
   1538 def repr_symbol_to_source():
   1539     return repr({
   1540         symbol: [s.name() for s in sources]
   1541         for symbol, sources in self.symbol_to_source.items()
   1542     })
-> 1544 assert self.symbol_to_source.get(expr), (
   1545     f"{expr} (could be from {[s.name() for s in self.var_to_sources[expr]]}) "
   1546     f"not in {repr_symbol_to_source()}. If this assert is failing, it could be "
   1547     "due to the issue described in https://github.com/pytorch/pytorch/pull/90665"
   1548 )
   1549 return self.source_ref(self.symbol_to_source[expr][0])

AssertionError: s2 (could be from ["L['msg']['infer']['prior']._batch_shape[0]"]) not in {s0: ["L['msg']['value'].size()[0]"], s1: ["L['msg']['value'].size()[1]", "L['msg']['value'].stride()[0]"], s5: [], s2: [], s4: [], s3: []}. If this assert is failing, it could be due to the issue described in https://github.com/pytorch/pytorch/pull/90665
```
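The shape-guard AssertionError above comes from TorchDynamo specialising the compiled graph on concrete tensor sizes. One possible mitigation (untested for cell2location, so treat it as an assumption) is to request dynamic-shape tracing with `dynamic=True`, so varying batch sizes don't trip size guards. A minimal sketch with the `eager` backend:

```python
import torch

def step(x):
    # Stand-in for a model/guide step over a batch of data.
    return (x * 2.0).sum()

# dynamic=True asks TorchDynamo to treat sizes symbolically up front,
# instead of specialising on the first batch shape it sees.
compiled_step = torch.compile(step, dynamic=True, backend="eager")

a = compiled_step(torch.ones(8, 3))   # batch of 8
b = compiled_step(torch.ones(16, 3))  # batch of 16, same compiled function
assert a.item() == 48.0
assert b.item() == 96.0
```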
Hi @adamgayoso and others (also cc @fritzo, @martinjankowiak, @eb8680),

It would be great if the new torch.compile function could be used with the Pyro model and guide in scvi-tools. I am happy to contribute this functionality; however, I need your recommendations on what to do with the following problem.

Suppose we add torch.compile as shown below. The problem is that Pyro creates guide parameters when they are first needed, which is why these callbacks are required:

scvi-tools/scvi/model/base/_pyromixin.py, lines 19 to 71 (commit a210867)

torch.compile(_guide) should similarly be called only after the parameters are created.

I see one solution to this. Run the following code

scvi-tools/scvi/model/base/_pyromixin.py, lines 65 to 71 (commit a210867)

i.e. call model.train() manually without using a callback, after creating the data loaders but before creating TrainRunner and TrainingPlan. Then modify the training plan accordingly.

What do you think about this? Do you have any better ideas on how to implement this?
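To make the warmup-then-compile flow concrete, here is a hypothetical sketch decoupled from scvi-tools internals. `DummyModule` and `compile_after_warmup` are illustrative names, not scvi-tools or Pyro API; the `module.model` / `module.guide` attribute layout only mimics the convention used above:

```python
import torch

class DummyModule:
    """Stand-in for a Pyro module exposing model/guide callables."""
    def __init__(self):
        self.model = lambda x: x * 2.0
        self.guide = lambda x: x + 1.0

def compile_after_warmup(module, warmup_batch):
    # Warmup: in real Pyro, this first call is what creates the guide's
    # parameters in the param store.
    module.guide(warmup_batch)
    # Only now swap in compiled callables (eager backend for portability).
    module.model = torch.compile(module.model, backend="eager")
    module.guide = torch.compile(module.guide, backend="eager")
    return module

module = compile_after_warmup(DummyModule(), torch.ones(3))
assert module.model(torch.ones(2)).tolist() == [2.0, 2.0]
assert module.guide(torch.zeros(2)).tolist() == [1.0, 1.0]
```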