From d4cc5e12cae0ab9cc1324eb06cd1692ca704a580 Mon Sep 17 00:00:00 2001 From: Ignacio Date: Mon, 27 May 2024 19:21:29 +0200 Subject: [PATCH] small changes --- qadence_libs/qinfo_tools/qfi.py | 8 ++++---- qadence_libs/qinfo_tools/qng.py | 4 ++-- tests/qinfo_tools/test_qng.py | 28 ++++++++++++++-------------- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/qadence_libs/qinfo_tools/qfi.py b/qadence_libs/qinfo_tools/qfi.py index 2ba4cc0..03cb522 100644 --- a/qadence_libs/qinfo_tools/qfi.py +++ b/qadence_libs/qinfo_tools/qfi.py @@ -10,10 +10,10 @@ from qadence_libs.qinfo_tools.utils import hessian -def _symsqrt(A: Tensor) -> Tensor: - """Computes the square root of a Symmetric or Hermitian positive definite matrix. +def _positive_semidefinite_sqrt(A: Tensor) -> Tensor: + """Computes the square root of a real positive semi-definite matrix. - Code from https://github.com/pytorch/pytorch/issues/25481#issuecomment-1032789228 + Code taken from https://github.com/pytorch/pytorch/issues/25481#issuecomment-1032789228 """ L, Q = torch.linalg.eigh(A) zero = torch.zeros((), device=L.device, dtype=L.dtype) @@ -149,7 +149,7 @@ def get_quantum_fisher_spsa( qfi_mat_estimator = a_k * (iteration * previous_qfi_estimator + qfi_mat) # type: ignore # Get the positive-semidefinite version of the matrix for the update rule in QNG - qfi_mat_positive_sd = _symsqrt(torch.matmul(qfi_mat_estimator, qfi_mat_estimator)) + qfi_mat_positive_sd = _positive_semidefinite_sqrt(qfi_mat_estimator @ qfi_mat_estimator) qfi_mat_positive_sd = qfi_mat_positive_sd + beta * torch.eye(ovrlp_model.num_vparams) qfi_mat_positive_sd = qfi_mat_positive_sd / (1 + beta) # regularization diff --git a/qadence_libs/qinfo_tools/qng.py b/qadence_libs/qinfo_tools/qng.py index 8c227f5..7b05c8b 100644 --- a/qadence_libs/qinfo_tools/qng.py +++ b/qadence_libs/qinfo_tools/qng.py @@ -63,8 +63,8 @@ def __init__( if isinstance(model, TransformedModule): logger.warning( "The model is of type '. " - "Keep in mind that the QNG optimizer can only optimize circuit variational " - "parameter. Input and output shifting/scaling parameters will not be optimized." + "Keep in mind that the QNG optimizer can only optimize circuit " + "parameters. Input and output shifting/scaling parameters will not be optimized." ) # Retrieve the quantum model from the TransformedModule model = model.model diff --git a/tests/qinfo_tools/test_qng.py b/tests/qinfo_tools/test_qng.py index 6eae813..f6fcc18 100644 --- a/tests/qinfo_tools/test_qng.py +++ b/tests/qinfo_tools/test_qng.py @@ -51,6 +51,20 @@ def sin_dataset(samples: int) -> tuple[Tensor, Tensor]: DATASETS = [quadratic_dataset(samples), sin_dataset(samples)] +def test_parameter_ordering(basic_optim_model: QuantumCircuit) -> None: + model = basic_optim_model + model.reset_vparams(torch.rand((len(model.vparams)))) + vparams_torch = [p.data for p in model.parameters() if p.requires_grad] + vparams_qadence = [v.data for v in model.vparams.values()] + assert len(vparams_torch) == len(vparams_qadence) + msg = ( + "The ordering of the output of the `vparams()` method in QuantumModel" + + "and the `parameters()` method in Torch is not consistent" + + "for variational parameters." + ) + assert vparams_torch == vparams_qadence, msg + + @pytest.mark.parametrize("dataset", DATASETS) @pytest.mark.parametrize("optim_config", OPTIMIZERS_CONFIG) def test_optims( @@ -76,17 +90,3 @@ def test_optims( if config["approximation"] == FisherApproximation.SPSA: assert optimizer.state["state"]["iter"] == iters assert optimizer.state["state"]["qfi_estimator"] is not None - - -def test_parameter_ordering(basic_optim_model: QuantumCircuit) -> None: - model = basic_optim_model - model.reset_vparams(torch.rand((len(model.vparams)))) - vparams_torch = [p.data for p in model.parameters() if p.requires_grad] - vparams_qadence = [v.data for v in model.vparams.values()] - assert len(vparams_torch) == len(vparams_qadence) - msg = ( - "The ordering of the output of the `vparams()` method in QuantumModel" - + "and the `parameters()` method in Torch is not consistent" - + "for variational parameters." - ) - assert vparams_torch == vparams_qadence, msg