Skip to content

Commit

Permalink
Merge pull request #66 from rochisha0/update-settings
Browse files Browse the repository at this point in the history
Update settings on change of backend
  • Loading branch information
Ericgig authored Aug 21, 2024
2 parents 2b0c79c + 1cd02cb commit 2db4223
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions src/qutip_jax/settings.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@
import jax.numpy as jnp
import numpy as np
from qutip import settings
from qutip import SESolver, MCSolver, MESolver

__all__ = ["use_jax_backend"]
__all__ = ["set_as_default"]

def use_jax_backend():
settings.core['numpy_backend'] = jnp
def set_as_default(*, revert=False):
if revert:
settings.core['numpy_backend'] = np
settings.core['default_dtype'] = None
SESolver.solver_options['method'] = 'adams'
MESolver.solver_options['method'] = 'adams'
MCSolver.solver_options['method'] = 'adams'
else:
settings.core['numpy_backend'] = jnp
settings.core['default_dtype'] = 'jax'
SESolver.solver_options['method'] = 'diffrax'
MESolver.solver_options['method'] = 'diffrax'
MCSolver.solver_options['method'] = 'diffrax'

0 comments on commit 2db4223

Please sign in to comment.