Skip to content

Commit

Permalink
enable mcsolve with jax.grad
Browse files Browse the repository at this point in the history
  • Loading branch information
rochisha0 committed Jul 22, 2024
1 parent 9dfc864 commit f7444b3
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion qutip/solver/mcsolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

__all__ = ['mcsolve', "MCSolver"]

import numpy as np
from ..core.numpy_backend import np
from numpy.typing import ArrayLike
from numpy.random import SeedSequence
from ..core import QobjEvo, spre, spost, Qobj, unstack_columns, qzero_like
Expand Down
16 changes: 8 additions & 8 deletions qutip/solver/multitraj.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from time import time
from .solver_base import Solver
from ..core import QobjEvo, Qobj
import numpy as np
from ..core.numpy_backend import np
from numpy.typing import ArrayLike
from numpy.random import SeedSequence
from numpy.random import SeedSequence, default_rng
from numbers import Number
from typing import Any, Callable
import bisect
Expand Down Expand Up @@ -87,7 +87,7 @@ def __init__(self, rhs, *, options=None):
else:
raise TypeError("The system should be a QobjEvo")
self.options = options
self.seed_sequence = np.random.SeedSequence()
self.seed_sequence = SeedSequence()
self._integrator = self._get_integrator()
self._state_metadata = {}
self.stats = self._initialize_stats()
Expand Down Expand Up @@ -360,15 +360,15 @@ def _read_seed(self, seed, ntraj):
"""
if seed is None:
seeds = self.seed_sequence.spawn(ntraj)
elif isinstance(seed, np.random.SeedSequence):
elif isinstance(seed, SeedSequence):
seeds = seed.spawn(ntraj)
elif not isinstance(seed, list):
seeds = np.random.SeedSequence(seed).spawn(ntraj)
seeds = SeedSequence(seed).spawn(ntraj)
elif len(seed) >= ntraj:
seeds = [
seed_ if (isinstance(seed_, np.random.SeedSequence)
seed_ if (isinstance(seed_, SeedSequence)
or hasattr(seed_, 'random'))
else np.random.SeedSequence(seed_)
else SeedSequence(seed_)
for seed_ in seed[:ntraj]
]
else:
Expand All @@ -391,7 +391,7 @@ def _get_generator(self, seed):
bit_gen = getattr(np.random, self.options['bitgenerator'])
generator = np.random.Generator(bit_gen(seed))
else:
generator = np.random.default_rng(seed)
generator = default_rng(seed)
return generator


Expand Down
2 changes: 1 addition & 1 deletion qutip/solver/multitrajresult.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

from typing import TypedDict
import numpy as np
from ..core.numpy_backend import np

from copy import copy

Expand Down
2 changes: 1 addition & 1 deletion qutip/solver/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from __future__ import annotations

from typing import TypedDict, Any, Callable
import numpy as np
from ..core.numpy_backend import np
from numpy.typing import ArrayLike
from ..core import Qobj, QobjEvo, expect

Expand Down

0 comments on commit f7444b3

Please sign in to comment.