Skip to content

Commit

Permalink
create a single function
Browse files Browse the repository at this point in the history
  • Loading branch information
rochisha0 committed Aug 21, 2024
1 parent b8117b4 commit 1cd02cb
Showing 1 changed file with 14 additions and 16 deletions.
30 changes: 14 additions & 16 deletions src/qutip_jax/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,18 @@
from qutip import settings
from qutip import SESolver, MCSolver, MESolver

__all__ = ["use_jax_backend"]

def use_jax_backend():
settings.core['numpy_backend'] = jnp
settings.core['default_dtype'] = 'jax'
SESolver.solver_options['method'] = 'diffrax'
MESolver.solver_options['method'] = 'diffrax'
MCSolver.solver_options['method'] = 'diffrax'


def use_numpy_backend():
settings.core['numpy_backend'] = np
settings.core['default_dtype'] = None
SESolver.solver_options['method'] = 'adams'
MESolver.solver_options['method'] = 'adams'
MCSolver.solver_options['method'] = 'adams'
__all__ = ["set_as_default"]

def set_as_default(*, revert=False):
if revert:
settings.core['numpy_backend'] = np
settings.core['default_dtype'] = None
SESolver.solver_options['method'] = 'adams'
MESolver.solver_options['method'] = 'adams'
MCSolver.solver_options['method'] = 'adams'
else:
settings.core['numpy_backend'] = jnp
settings.core['default_dtype'] = 'jax'
SESolver.solver_options['method'] = 'diffrax'
MESolver.solver_options['method'] = 'diffrax'
MCSolver.solver_options['method'] = 'diffrax'

0 comments on commit 1cd02cb

Please sign in to comment.