forked from qutip/qutip
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request qutip#2513 from qutip/dev.major
Merge `jax` support PR in `dev.major` into `master`.
- Loading branch information
Showing
14 changed files
with
110 additions
and
48 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Add support for `jit` and `grad` in qutip.core.metrics |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
This pull request introduces a new `NumpyBackend `class that enables dynamic selection of the numpy_backend used in `qutip`. The class facilitates switching between different numpy implementations ( `numpy` and `jax.numpy` mainly) based on the configuration specified in the `settings.core` dictionary. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Enable mcsolve with jax.grad using numpy_backend |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
`clip` gives deprecation warning, that might be a problem in the future. Hence switch to `where` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from ..settings import settings | ||
|
||
|
||
class NumpyBackend: | ||
def _qutip_setting_backend(self, np): | ||
self._qt_np = np | ||
|
||
def __getattr__(self, name): | ||
return getattr(self._qt_np, name) | ||
|
||
|
||
# Initialize the numpy backend | ||
np = NumpyBackend() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import pytest | ||
import numpy | ||
from unittest.mock import Mock | ||
|
||
from qutip.core.numpy_backend import np | ||
from qutip import CoreOptions | ||
|
||
# Mocking JAX to demonstrate backend switching | ||
mock_jax = Mock() | ||
mock_np = numpy | ||
|
||
|
||
class TestNumpyBackend: | ||
def test_getattr_numpy(self): | ||
with CoreOptions(numpy_backend=mock_np): | ||
assert np.sum([1, 2, 3]) == numpy.sum([1, 2, 3]) | ||
assert np.sum is numpy.sum | ||
|
||
def test_getattr_jax(self): | ||
with CoreOptions(numpy_backend=mock_jax): | ||
mock_jax.sum = Mock(return_value="jax_sum") | ||
assert np.sum([1, 2, 3]) == "jax_sum" |
Oops, something went wrong.