Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints #3

Open
wants to merge 40 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
bb7c7d7
:label: Annotate bitops.py
EarlMilktea Dec 22, 2023
a6d6f28
:art: Remove len(arr.shape)
EarlMilktea Dec 22, 2023
2298e62
:technologist: Add typeutil
EarlMilktea Mar 3, 2024
bea91a1
:construction: Annotate state_vector/type.py
EarlMilktea Mar 3, 2024
aba0838
:construction: Annotate operator/type.py
EarlMilktea Mar 3, 2024
bd679cd
:safety_vest: Add type validation
EarlMilktea Mar 3, 2024
a24b604
:safety_vest: Add eincheck.py
EarlMilktea Mar 3, 2024
20e1d70
:construction: Annotate trace.py
EarlMilktea Mar 3, 2024
4f3d619
:art: Use np.real to use implicit cast
EarlMilktea Mar 3, 2024
8a2680f
:rotating_light: Cast to fix type
EarlMilktea Mar 3, 2024
d71f151
:construction: Annotate state_vector module
EarlMilktea Mar 3, 2024
119172f
:art: Add OPAutomaton
EarlMilktea Mar 3, 2024
65ee3c4
:art: Remove inverse slice
EarlMilktea Mar 3, 2024
ebc5473
:boom: Use OPAutomaton instead of list
EarlMilktea Mar 3, 2024
e15802b
:construction: Annotate operator module
EarlMilktea Mar 3, 2024
619f24a
:safety_vest: Add string overload
EarlMilktea Mar 3, 2024
94d0e28
:safety_vest: Update mps as ABC
EarlMilktea Mar 3, 2024
29c6fa0
:bug: Change index type
EarlMilktea Mar 3, 2024
e188a1c
:safety_vest: Add typecheck
EarlMilktea Mar 3, 2024
ef48b6f
:bug: Add default value
EarlMilktea Mar 3, 2024
1c3462b
:bug: Expose members
EarlMilktea Mar 3, 2024
8e7d8b1
:construction: Annotate mps module
EarlMilktea Mar 3, 2024
9d34214
:art: Improve flow
EarlMilktea Mar 3, 2024
d0033fe
:wheelchair: Add OPSeqElement
EarlMilktea Mar 3, 2024
213bdfe
:construction: Annotate dispatch.py
EarlMilktea Mar 3, 2024
c7c3d9d
:construction: Annotate alg module
EarlMilktea Mar 3, 2024
1a9f91e
:heavy_plus_sign: Add typing-extensions
EarlMilktea Mar 3, 2024
d4c7497
:fire: Remove OPAutomaton
EarlMilktea Mar 3, 2024
4754b41
:bug: Add default arg
EarlMilktea Mar 3, 2024
355d2bc
:art: Remove np.einsum
EarlMilktea Mar 3, 2024
225d13b
:art: Remove untyped list
EarlMilktea Mar 3, 2024
1578b7a
:safety_vest: Annotate fallfack
EarlMilktea Mar 3, 2024
817e52b
:bug: Use relative import
EarlMilktea Mar 3, 2024
91e390a
:white_check_mark: Update tests
EarlMilktea Mar 3, 2024
6556fa5
:technologist: Remove np.einsum
EarlMilktea Mar 3, 2024
c242152
:label: Add py.typed
EarlMilktea Mar 3, 2024
b8eacf9
:bug: Add __future__
EarlMilktea Mar 3, 2024
5debf73
:heavy_minus_sign: Import from typing
EarlMilktea Mar 3, 2024
66214ec
:heavy_plus_sign: Move typing-extensions to deps
EarlMilktea Mar 3, 2024
f7538ae
:bug: Add mps compatibility
EarlMilktea Mar 3, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,11 @@ build-backend = "setuptools.build_meta"
name = "qailo"
description = "Simplest Quantum Circuit Simulator"
readme = "README.md"
authors = [
{ name = "Synge Todo", email = "[email protected]" },
]
authors = [{ name = "Synge Todo", email = "[email protected]" }]
license = { file = "LICENSE" }
classifiers = [
"Programming Language :: Python :: 3",
]
classifiers = ["Programming Language :: Python :: 3"]
dynamic = ["version"]
dependencies = ["matplotlib", "numpy"]
dependencies = ["matplotlib", "numpy", "typing-extensions"]

[tool.setuptools.dynamic]
version = { attr = "qailo._version.version" }
Expand All @@ -25,6 +21,9 @@ write_to = "src/qailo/_version.py"
[tool.setuptools]
package-dir = { "" = "src" }

[tool.setuptools.package-data]
"qailo" = ["py.typed"]

[project.optional-dependencies]
dev = ["pytest", "black", "ruff"]

Expand Down
21 changes: 12 additions & 9 deletions src/qailo/alg/qft.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,39 @@
# ref: https://learn.qiskit.org/course/ch-algorithms/quantum-fourier-transform

from __future__ import annotations

import numpy as np

import qailo as q
from qailo.util.helpertype import OPSeqElement


def qft_rotations_seq(n):
seq = []
def qft_rotations_seq(n: int) -> list[OPSeqElement]:
seq: list[OPSeqElement] = []
if n == 0:
return seq
n -= 1
# print(f"H on [{n}]")
seq.append([q.op.h(), [n]])
seq.append(OPSeqElement(q.op.h(), [n]))
for p in range(n):
# print(f"CP(pi/{2**(n-p)} on [{p}, {n}]")
seq.append([q.op.cp(np.pi / 2 ** (n - p)), [p, n]])
seq.append(OPSeqElement(q.op.cp(np.pi / 2 ** (n - p)), [p, n]))
seq += qft_rotations_seq(n)
return seq


def swap_registers_seq(n):
seq = []
def swap_registers_seq(n: int) -> list[OPSeqElement]:
seq: list[OPSeqElement] = []
for p in range(n // 2):
# print(f"swap on [{p}, {n-p-1}]")
seq.append([q.op.swap(), [p, n - p - 1]])
seq.append(OPSeqElement(q.op.swap(), [p, n - p - 1]))
return seq


def qft_seq(n):
def qft_seq(n: int) -> list[OPSeqElement]:
"""QFT on the first n qubits in circuit"""
return qft_rotations_seq(n) + swap_registers_seq(n)


def inverse_qft_seq(n):
def inverse_qft_seq(n: int) -> list[OPSeqElement]:
return q.op.inverse_seq(qft_seq(n))
15 changes: 14 additions & 1 deletion src/qailo/alg/qpe.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,24 @@
from __future__ import annotations

from copy import deepcopy
from typing import overload

import numpy as np
import numpy.typing as npt

import qailo as q
from qailo.mps.type import mps


@overload
def qpe(n: int, u: npt.NDArray, v: npt.NDArray) -> npt.NDArray: ...


@overload
def qpe(n: int, u: npt.NDArray, v: mps) -> mps: ...


def qpe(n, u, v):
def qpe(n: int, u: npt.NDArray, v: npt.NDArray | mps) -> npt.NDArray | mps:
m = q.num_qubits(u)
w = deepcopy(v)
assert q.num_qubits(w) == m + n
Expand Down
49 changes: 42 additions & 7 deletions src/qailo/dispatch.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,32 @@
from __future__ import annotations

from typing import Container, Iterable, overload

import numpy as np
import numpy.typing as npt

from . import mps
from . import operator as op
from . import state_vector as sv
from .mps import type as mpstype
from .util.helpertype import OPSeqElement


@overload
def apply(
v: npt.NDArray, p: npt.NDArray, pos: list[int] | None = None
) -> npt.NDArray: ...


@overload
def apply(
v: mpstype.mps, p: npt.NDArray, pos: list[int] | None = None
) -> mpstype.mps: ...


def apply(v, p, pos=None):
def apply(
v: npt.NDArray | mpstype.mps, p: npt.NDArray, pos: list[int] | None = None
) -> npt.NDArray | mpstype.mps:
if sv.is_state_vector(v):
v = sv.apply(v, p, pos)
elif mps.is_mps(v):
Expand All @@ -15,7 +36,17 @@ def apply(v, p, pos=None):
return v


def apply_seq(v, seq):
@overload
def apply_seq(v: npt.NDArray, seq: Iterable[OPSeqElement]) -> npt.NDArray: ...


@overload
def apply_seq(v: mpstype.mps, seq: Iterable[OPSeqElement]) -> mpstype.mps: ...


def apply_seq(
v: npt.NDArray | mpstype.mps, seq: Iterable[OPSeqElement]
) -> npt.NDArray | mpstype.mps:
if sv.is_state_vector(v):
v = sv.apply_seq(v, seq)
elif mps.is_mps(v):
Expand All @@ -25,14 +56,16 @@ def apply_seq(v, seq):
return v


def norm(v):
def norm(v: npt.NDArray | mpstype.mps) -> float:
if sv.is_state_vector(v):
return np.linalg.norm(v)
return float(np.linalg.norm(v))
elif mps.is_mps(v):
return v._norm()
else:
assert False


def num_qubits(v):
def num_qubits(v: npt.NDArray | mpstype.mps) -> int:
if sv.is_state_vector(v):
return sv.num_qubits(v)
elif op.is_operator(v):
Expand All @@ -42,15 +75,17 @@ def num_qubits(v):
assert False


def probability(v, pos=None):
def probability(
v: npt.NDArray | mpstype.mps, pos: Container[int] | None = None
) -> npt.NDArray:
if sv.is_state_vector(v):
return sv.probability(v, pos)
elif mps.is_mps(v):
return sv.probability(mps.state_vector(v), pos)
assert False


def vector(v, c=None):
def vector(v: npt.NDArray | mpstype.mps, c: list | None = None) -> npt.NDArray:
if sv.is_state_vector(v):
return sv.vector(v, c)
elif mps.is_mps(v):
Expand Down
18 changes: 12 additions & 6 deletions src/qailo/mps/apply.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from __future__ import annotations

from copy import deepcopy
from typing import Iterable, Sequence

import numpy.typing as npt

from ..operator import type as op
from ..operator.swap import swap
from ..util.helpertype import OPSeqElement
from . import type as mps


def _swap_tensors(m, s):
def _swap_tensors(m: mps.mps, s: int) -> None:
"""
swap neighboring two tensors at s and s+1
"""
Expand All @@ -16,7 +22,7 @@ def _swap_tensors(m, s):
m.t2q[s], m.t2q[s + 1] = p1, p0


def _move_qubit(m, p, s):
def _move_qubit(m: mps.mps, p: int, s: int) -> None:
if m.q2t[p] != s:
# print(f"moving qubit {p} at {m.q2t[p]} to {s}")
for u in range(m.q2t[p], s):
Expand All @@ -27,7 +33,7 @@ def _move_qubit(m, p, s):
_swap_tensors(m, u - 1)


def _apply(m, p, pos=None):
def _apply(m: mps.mps, p: npt.NDArray, pos: Sequence[int] | None = None) -> mps.mps:
assert op.is_operator(p)
n = mps.num_qubits(m)
if pos is None:
Expand All @@ -50,15 +56,15 @@ def _apply(m, p, pos=None):
return m


def apply(m, p, pos=None):
def apply(m: mps.mps, p: npt.NDArray, pos: Sequence[int] | None = None) -> mps.mps:
return _apply(deepcopy(m), p, pos)


def _apply_seq(m, seq):
def _apply_seq(m: mps.mps, seq: Iterable[OPSeqElement]) -> mps.mps:
for p, qubit in seq:
_apply(m, p, qubit)
return m


def apply_seq(m, seq):
def apply_seq(m: mps.mps, seq: Iterable[OPSeqElement]) -> mps.mps:
return _apply_seq(deepcopy(m), seq)
Loading
Loading