Skip to content

Commit

Permalink
Black style updates
Browse files Browse the repository at this point in the history
  • Loading branch information
BoxiLi committed Jun 13, 2024
1 parent 0e91030 commit 0a1e990
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 11 deletions.
16 changes: 9 additions & 7 deletions src/krotov/objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
objective, there are usually several different functionals whose minimization
achieve that objective.
"""

import copy
import itertools
import sys
Expand All @@ -16,6 +17,7 @@
import numpy as np
import qutip
from packaging.version import parse as parse_version

if parse_version(qutip.__version__) < parse_version("5"):
is_qutip5 = False
else:
Expand Down Expand Up @@ -332,7 +334,7 @@ def mesolve(
c_ops=c_ops,
e_ops=e_ops,
args=args,
**kwargs
**kwargs,
)

def propagate(
Expand Down Expand Up @@ -412,7 +414,7 @@ def propagate(
if len(e_ops) == 0:
result.states.append(state)
else:
for (i, oper) in enumerate(e_ops):
for i, oper in enumerate(e_ops):
result.expect[i].append(expect(oper, state))
controls = extract_controls([self])
pulses_mapping = extract_controls_mapping([self], controls)
Expand Down Expand Up @@ -441,7 +443,7 @@ def propagate(
if len(e_ops) == 0:
result.states.append(state)
else:
for (i, oper) in enumerate(e_ops):
for i, oper in enumerate(e_ops):
result.expect[i].append(expect(oper, state))
if not is_qutip5:
result.expect = [np.array(a) for a in result.expect]
Expand Down Expand Up @@ -662,7 +664,7 @@ def _plug_in_array_controls_as_func(H, controls, mapping, tlist):
H = _nested_list_shallow_copy(H)
T = tlist[-1]
nt = len(tlist)
for (control, control_mapping) in zip(controls, mapping):
for control, control_mapping in zip(controls, mapping):
if isinstance(control, np.ndarray):
for i in control_mapping:
# Use the same formula that QuTiP normally passes to Cython for
Expand Down Expand Up @@ -975,8 +977,8 @@ def gate_objectives(
# complexity (and make the repr of an Objective look nicer) by identifying
# this and setting the mapped_basis_states to the identical objects as the
# original basis_states
for (i, state) in enumerate(mapped_basis_states):
for (j, basis_state) in enumerate(basis_states):
for i, state in enumerate(mapped_basis_states):
for j, basis_state in enumerate(basis_states):
if state == basis_state:
mapped_basis_states[i] = basis_state
if liouville_states_set is None:
Expand Down Expand Up @@ -1037,7 +1039,7 @@ def gate_objectives(
if normalize_weights:
N = len(objectives)
weights = N * np.array(weights) / np.sum(weights)
for (i, weight) in _reversed_enumerate(weights):
for i, weight in _reversed_enumerate(weights):
weight = float(weight)
if weight < 0:
raise ValueError("weights must be greater than zero")
Expand Down
12 changes: 8 additions & 4 deletions src/krotov/propagators.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,15 @@
.. _Cython: https://cython.org
"""

from abc import ABC, abstractmethod

import numpy as np
import qutip
import scipy
import threadpoolctl
from packaging.version import parse as parse_version

if parse_version(qutip.__version__) < parse_version("5"):
is_qutip5 = False
from qutip.cy.spconvert import dense2D_to_fastcsr_fmode
Expand Down Expand Up @@ -262,7 +264,7 @@ def __call__(
unstack_columns(self._y),
dims=state.dims,
isherm=True,
dtype="csr"
dtype="csr",
)
else:
return qutip.Qobj(
Expand Down Expand Up @@ -303,7 +305,7 @@ def _initialize_data(self, L, rho, dt, c_ops, backwards):
if not (c_ops is None or len(c_ops) == 0):
# in principle, we could convert c_ops to a Lindbladian, here
raise NotImplementedError("c_ops not implemented")
for (i, spec) in enumerate(L):
for i, spec in enumerate(L):
if isinstance(spec, qutip.Qobj):
l_op = spec
l_coeff = 1
Expand All @@ -323,10 +325,12 @@ def _initialize_data(self, L, rho, dt, c_ops, backwards):
)
self._L_list = L_list
self._control_indices = control_indices

if rho.type == 'oper':
if is_qutip5:
self._y = unstack_columns(rho.full()).ravel('F') # initial state
self._y = unstack_columns(rho.full()).ravel(
'F'
) # initial state
else:
self._y = mat2vec(rho.full()).ravel('F') # initial state
else:
Expand Down
3 changes: 3 additions & 0 deletions tests/test_dump_result.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test the dump_result convergence routine"""

import copy
import os

Expand All @@ -7,11 +8,13 @@

import qutip
from packaging.version import parse as parse_version

if parse_version(qutip.__version__) < parse_version("5"):
oct_result_name = "oct_result_qutip4.dump"
else:
oct_result_name = "oct_result_qutip5.dump"


def incl_range(a, b, step=1):
e = 1 if step > 0 else -1
return range(a, b + e, step)
Expand Down

0 comments on commit 0a1e990

Please sign in to comment.