Skip to content

Commit

Permalink
small changes
Browse files Browse the repository at this point in the history
  • Loading branch information
inafergra committed May 27, 2024
1 parent 29cf30d commit d4cc5e1
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 20 deletions.
8 changes: 4 additions & 4 deletions qadence_libs/qinfo_tools/qfi.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from qadence_libs.qinfo_tools.utils import hessian


def _symsqrt(A: Tensor) -> Tensor:
"""Computes the square root of a Symmetric or Hermitian positive definite matrix.
def _positive_semidefinite_sqrt(A: Tensor) -> Tensor:
"""Computes the square root of a real positive semi-definite matrix.
Code from https://github.com/pytorch/pytorch/issues/25481#issuecomment-1032789228
Code taken from https://github.com/pytorch/pytorch/issues/25481#issuecomment-1032789228
"""
L, Q = torch.linalg.eigh(A)
zero = torch.zeros((), device=L.device, dtype=L.dtype)
Expand Down Expand Up @@ -149,7 +149,7 @@ def get_quantum_fisher_spsa(
qfi_mat_estimator = a_k * (iteration * previous_qfi_estimator + qfi_mat) # type: ignore

# Get the positive-semidefinite version of the matrix for the update rule in QNG
qfi_mat_positive_sd = _symsqrt(torch.matmul(qfi_mat_estimator, qfi_mat_estimator))
qfi_mat_positive_sd = _positive_semidefinite_sqrt(qfi_mat_estimator @ qfi_mat_estimator)
qfi_mat_positive_sd = qfi_mat_positive_sd + beta * torch.eye(ovrlp_model.num_vparams)
qfi_mat_positive_sd = qfi_mat_positive_sd / (1 + beta) # regularization

Expand Down
4 changes: 2 additions & 2 deletions qadence_libs/qinfo_tools/qng.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def __init__(
if isinstance(model, TransformedModule):
logger.warning(
"The model is of type '<class TransformedModule>. "
"Keep in mind that the QNG optimizer can only optimize circuit variational "
"parameter. Input and output shifting/scaling parameters will not be optimized."
"Keep in mind that the QNG optimizer can only optimize circuit "
"parameters. Input and output shifting/scaling parameters will not be optimized."
)
# Retrieve the quantum model from the TransformedModule
model = model.model
Expand Down
28 changes: 14 additions & 14 deletions tests/qinfo_tools/test_qng.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,20 @@ def sin_dataset(samples: int) -> tuple[Tensor, Tensor]:
DATASETS = [quadratic_dataset(samples), sin_dataset(samples)]


def test_parameter_ordering(basic_optim_model: QuantumCircuit) -> None:
model = basic_optim_model
model.reset_vparams(torch.rand((len(model.vparams))))
vparams_torch = [p.data for p in model.parameters() if p.requires_grad]
vparams_qadence = [v.data for v in model.vparams.values()]
assert len(vparams_torch) == len(vparams_qadence)
msg = (
"The ordering of the output of the `vparams()` method in QuantumModel"
+ "and the `parameters()` method in Torch is not consistent"
+ "for variational parameters."
)
assert vparams_torch == vparams_qadence, msg


@pytest.mark.parametrize("dataset", DATASETS)
@pytest.mark.parametrize("optim_config", OPTIMIZERS_CONFIG)
def test_optims(
Expand All @@ -76,17 +90,3 @@ def test_optims(
if config["approximation"] == FisherApproximation.SPSA:
assert optimizer.state["state"]["iter"] == iters
assert optimizer.state["state"]["qfi_estimator"] is not None


def test_parameter_ordering(basic_optim_model: QuantumCircuit) -> None:
model = basic_optim_model
model.reset_vparams(torch.rand((len(model.vparams))))
vparams_torch = [p.data for p in model.parameters() if p.requires_grad]
vparams_qadence = [v.data for v in model.vparams.values()]
assert len(vparams_torch) == len(vparams_qadence)
msg = (
"The ordering of the output of the `vparams()` method in QuantumModel"
+ "and the `parameters()` method in Torch is not consistent"
+ "for variational parameters."
)
assert vparams_torch == vparams_qadence, msg

0 comments on commit d4cc5e1

Please sign in to comment.