Skip to content

Commit

Permalink
Merge pull request #61 from cirKITers/inputs-expr
Browse files Browse the repository at this point in the history
optional none input
  • Loading branch information
majafranz authored Nov 15, 2024
2 parents 27f22b2 + 9bbe5ff commit 46e3cb4
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 10 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "qml-essentials"
version = "0.1.16"
version = "0.1.17"
description = ""
authors = ["Melvin Strobl <[email protected]>", "Maja Franz <[email protected]>"]
maintainers = ["Melvin Strobl <[email protected]>"]
Expand Down
32 changes: 25 additions & 7 deletions qml_essentials/expressibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,16 @@ def _sample_state_fidelities(
"""
rng = np.random.default_rng(seed)

# Number of input samples
# Generate random parameter sets
# We need two sets of parameters, as we are computing fidelities for a
# pair of random state vectors
model.initialize_params(rng=rng, repeat=n_samples * 2)

n_x_samples = len(x_samples)

# Initialize array to store fidelities
fidelities: np.ndarray = np.zeros((n_x_samples, n_samples))

# Generate random parameter sets
# We need two sets of parameters, as we are computing fidelities for a
# pair of random state vectors
model.initialize_params(rng=rng, repeat=n_samples * 2)
# Batch input samples and parameter sets for efficient computation
x_samples_batched: np.ndarray = x_samples.reshape(1, -1).repeat(
n_samples * 2, axis=0
Expand Down Expand Up @@ -70,6 +70,7 @@ def _sample_state_fidelities(
)
** 2
)
# TODO: abs instead?
fidelities[idx] = np.real(fidelity)

return fidelities
Expand All @@ -82,6 +83,7 @@ def state_fidelities(
n_input_samples: int,
input_domain: List[float],
model: Model,
scale: bool = False,
**kwargs: Any,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Expand All @@ -94,14 +96,22 @@ def state_fidelities(
n_input_samples (int): Number of input samples.
input_domain (List[float]): Input domain.
model (Callable): Function that models the quantum circuit.
scale (bool): Whether to scale the number of samples and bins.
kwargs (Any): Additional keyword arguments for the model function.
Returns:
Tuple[np.ndarray, np.ndarray, np.ndarray]: Tuple containing the
input samples, bin edges, and histogram values.
"""
if scale:
n_samples = np.power(2, model.n_qubits) * n_samples
n_bins = model.n_qubits * n_bins

x = np.linspace(*input_domain, n_input_samples, requires_grad=False)
if input_domain is None or n_input_samples is None or n_input_samples == 0:
x = np.zeros((1))
n_input_samples = 1
else:
x = np.linspace(*input_domain, n_input_samples, requires_grad=False)

fidelities = Expressibility._sample_state_fidelities(
x_samples=x,
Expand All @@ -119,6 +129,9 @@ def state_fidelities(

z = z / n_samples

if z.shape[0] == 1:
z = z.flatten()

return x, y, z

@staticmethod
Expand Down Expand Up @@ -168,6 +181,7 @@ def haar_integral(
n_qubits: int,
n_bins: int,
cache: bool = True,
scale: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
"""
Calculates theoretical probability density function for random Haar states
Expand All @@ -178,17 +192,21 @@ def haar_integral(
n_qubits (int): number of qubits in the quantum system
n_bins (int): number of histogram bins
cache (bool): whether to cache the haar integral
scale (bool): whether to scale the number of bins
Returns:
Tuple[np.ndarray, np.ndarray]:
- x component (bins): the input domain
- y component (probabilities): the haar probability density
funtion for random Haar states
"""
if scale:
n_bins = n_qubits * n_bins

x = np.linspace(0, 1, n_bins)

if cache:
name = f"haar_{n_qubits}q_{n_bins}s.npy"
name = f"haar_{n_qubits}q_{n_bins}s_{'scaled' if scale else ''}.npy"

cache_folder = ".cache"
if not os.path.exists(cache_folder):
Expand Down
5 changes: 3 additions & 2 deletions qml_essentials/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,9 +595,10 @@ def _forward(
inputs=inputs,
)

if isinstance(result, list):
result = np.stack(result)

if self.execution_type == "expval" and self.output_qubit == -1:
if isinstance(result, list):
result = np.stack(result)

# Calculating mean value after stacking, to not
# discard gradient information
Expand Down
33 changes: 33 additions & 0 deletions tests/test_expressiblity.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,36 @@ def test_expressibility() -> None:
), f"Expressibility is not {test_case['result']}\
for circuit ansatz {test_case['circuit_type']}.\
Was {kl_dist} instead"


@pytest.mark.unittest
@pytest.mark.expensive
def test_scaling() -> None:
model = Model(
n_qubits=2,
n_layers=1,
circuit_type="Circuit_1",
)

_, _, z = Expressibility.state_fidelities(
seed=1000,
n_bins=4,
n_samples=10,
n_input_samples=0,
input_domain=[0, 2 * np.pi],
model=model,
scale=True,
)

assert z.shape == (8,)

_, y = Expressibility.haar_integral(
n_qubits=model.n_qubits,
n_bins=4,
cache=False,
scale=True,
)

assert y.shape == (8,)

_ = Expressibility.kullback_leibler_divergence(z, y)

0 comments on commit 46e3cb4

Please sign in to comment.