Skip to content

Commit

Permalink
Add API for required module metadata (#226)
Browse files Browse the repository at this point in the history
  • Loading branch information
idavis authored Jan 23, 2023
1 parent 43bd7f4 commit b49f671
Show file tree
Hide file tree
Showing 8 changed files with 270 additions and 63 deletions.
37 changes: 8 additions & 29 deletions examples/dynamic_allocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,36 +13,15 @@
)

context = Context()
mod = Module(context, "dynamic_allocation")
builder = Builder(context)

# Define module flags
i1 = pyqir.IntType(context, 1)
i32 = pyqir.IntType(context, 32)

mod.add_flag(
ModuleFlagBehavior.ERROR,
"qir_major_version",
pyqir.const(i32, 1),
)

mod.add_flag(
ModuleFlagBehavior.MAX,
"qir_minor_version",
pyqir.const(i32, 0),
)

mod.add_flag(
ModuleFlagBehavior.ERROR,
"dynamic_qubit_management",
pyqir.const(i1, True),
)

mod.add_flag(
ModuleFlagBehavior.ERROR,
"dynamic_result_management",
pyqir.const(i1, True),
mod = pyqir.qir_module(
context,
"dynamic_allocation",
qir_major_version=1,
qir_minor_version=0,
dynamic_qubit_management=True,
dynamic_result_management=True,
)
builder = Builder(context)

# define external calls and type definitions
qubit_type = pyqir.qubit_type(context)
Expand Down
10 changes: 10 additions & 0 deletions pyqir/pyqir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
Type,
Value,
const,
dynamic_qubit_management,
dynamic_result_management,
entry_point,
extract_byte_string,
global_byte_string,
Expand All @@ -45,6 +47,9 @@
qubit,
qubit_id,
qubit_type,
qir_major_version,
qir_minor_version,
qir_module,
required_num_qubits,
required_num_results,
result,
Expand Down Expand Up @@ -90,6 +95,8 @@
"Type",
"Value",
"const",
"dynamic_qubit_management",
"dynamic_result_management",
"entry_point",
"extract_byte_string",
"global_byte_string",
Expand All @@ -100,6 +107,9 @@
"qubit_id",
"qubit_type",
"qubit",
"qir_major_version",
"qir_minor_version",
"qir_module",
"required_num_qubits",
"required_num_results",
"result_id",
Expand Down
38 changes: 38 additions & 0 deletions pyqir/pyqir/_native.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,44 @@ def entry_point(
"""
...

def qir_major_version(module: Module) -> Optional[int]:
"""The QIR major version this module is built for. None if unspecified."""
...

def qir_minor_version(module: Module) -> Optional[int]:
"""The QIR minor version this module is built for. None if unspecified."""
...

def dynamic_qubit_management(module: Module) -> Optional[bool]:
"""Whether this module supports dynamic qubit management. None if unspecified."""
...

def dynamic_result_management(module: Module) -> Optional[bool]:
"""Whether this module supports dynamic result management. None if unspecified."""
...

def qir_module(
context: Context,
name: str,
qir_major_version: int = 1,
qir_minor_version: int = 0,
dynamic_qubit_management: bool = False,
dynamic_result_management: bool = False,
) -> Module:
"""
Creates a module with required QIR module flag metadata
:param Context context: The parent context.
:param str name: The module name.
:param int qir_major_version: The QIR major version this module is built for. Default 1.
:param int qir_minor_version: The QIR minor version this module is built for. Default 0.
:param bool dynamic_qubit_management: Whether this module supports dynamic qubit management. Default False.
:param bool dynamic_result_management: Whether this module supports dynamic result management. Default False.
:returns: A module with the QIR module flags initialized
:rtype: Module
"""
...

def extract_byte_string(value: Value) -> Optional[bytes]:
"""
If the value is a pointer to a constant byte string, extracts it.
Expand Down
36 changes: 8 additions & 28 deletions pyqir/pyqir/_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,41 +45,21 @@ def __init__(
if context is None:
context = Context()

self._module = Module(context, name)
self._module = pyqir.qir_module(
context,
name,
qir_major_version=1,
qir_minor_version=0,
dynamic_qubit_management=False,
dynamic_result_management=False,
)
self._builder = Builder(context)
self._num_qubits = num_qubits
self._num_results = num_results

entry_point = pyqir.entry_point(self._module, "main", num_qubits, num_results)
self._builder.insert_at_end(BasicBlock(context, "entry", entry_point))

i1 = pyqir.IntType(context, 1)
i32 = pyqir.IntType(context, 32)

self._module.add_flag(
ModuleFlagBehavior.ERROR,
"qir_major_version",
pyqir.const(i32, 1),
)

self._module.add_flag(
ModuleFlagBehavior.MAX,
"qir_minor_version",
pyqir.const(i32, 0),
)

self._module.add_flag(
ModuleFlagBehavior.ERROR,
"dynamic_qubit_management",
pyqir.const(i1, False),
)

self._module.add_flag(
ModuleFlagBehavior.ERROR,
"dynamic_result_management",
pyqir.const(i1, False),
)

@property
def context(self) -> Context:
"""The LLVM context."""
Expand Down
14 changes: 10 additions & 4 deletions pyqir/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ use crate::{
PointerType, StructType, Type,
},
values::{
entry_point, extract_byte_string, global_byte_string, is_entry_point, is_interop_friendly,
qubit, qubit_id, r#const, required_num_qubits, required_num_results, result, result_id,
Attribute, AttributeList, AttributeSet, BasicBlock, Constant, FloatConstant, Function,
IntConstant, Value,
dynamic_qubit_management, dynamic_result_management, entry_point, extract_byte_string,
global_byte_string, is_entry_point, is_interop_friendly, qir_major_version,
qir_minor_version, qir_module, qubit, qubit_id, r#const, required_num_qubits,
required_num_results, result, result_id, Attribute, AttributeList, AttributeSet,
BasicBlock, Constant, FloatConstant, Function, IntConstant, Value,
},
};
use pyo3::prelude::*;
Expand Down Expand Up @@ -61,13 +62,18 @@ fn _native(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<Switch>()?;
m.add_class::<Type>()?;
m.add_class::<Value>()?;
m.add_function(wrap_pyfunction!(dynamic_qubit_management, m)?)?;
m.add_function(wrap_pyfunction!(dynamic_result_management, m)?)?;
m.add_function(wrap_pyfunction!(entry_point, m)?)?;
m.add_function(wrap_pyfunction!(extract_byte_string, m)?)?;
m.add_function(wrap_pyfunction!(global_byte_string, m)?)?;
m.add_function(wrap_pyfunction!(is_entry_point, m)?)?;
m.add_function(wrap_pyfunction!(is_interop_friendly, m)?)?;
m.add_function(wrap_pyfunction!(is_qubit_type, m)?)?;
m.add_function(wrap_pyfunction!(is_result_type, m)?)?;
m.add_function(wrap_pyfunction!(qir_major_version, m)?)?;
m.add_function(wrap_pyfunction!(qir_minor_version, m)?)?;
m.add_function(wrap_pyfunction!(qir_module, m)?)?;
m.add_function(wrap_pyfunction!(qubit_id, m)?)?;
m.add_function(wrap_pyfunction!(qubit_type, m)?)?;
m.add_function(wrap_pyfunction!(qubit, m)?)?;
Expand Down
73 changes: 73 additions & 0 deletions pyqir/src/values.rs
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,79 @@ pub(crate) fn required_num_results(function: PyRef<Function>) -> Option<u64> {
unsafe { values::required_num_results(function.into_super().into_super().as_ptr()) }
}

/// Creates a module with required QIR module flag metadata
///
/// :param Context context: The parent context.
/// :param str name: The module name.
/// :param int qir_major_version: The QIR major version this module is built for. Default 1.
/// :param int qir_minor_version: The QIR minor version this module is built for. Default 0.
/// :param bool dynamic_qubit_management: Whether this module supports dynamic qubit management. Default False.
/// :param bool dynamic_result_management: Whether this module supports dynamic result management. Default False.
/// :rtype: Module
#[pyfunction]
#[pyo3(
text_signature = "(context, name, qir_major_version, qir_minor_version, dynamic_qubit_management, dynamic_result_management)"
)]
pub(crate) fn qir_module(
py: Python,
context: Py<Context>,
name: &str,
qir_major_version: Option<i32>,
qir_minor_version: Option<i32>,
dynamic_qubit_management: Option<bool>,
dynamic_result_management: Option<bool>,
) -> PyResult<PyObject> {
let module = crate::module::Module::new(py, context, name);
let ptr = module.as_ptr();
unsafe {
qirlib::module::set_qir_major_version(ptr, qir_major_version.unwrap_or(1));
}
unsafe {
qirlib::module::set_qir_minor_version(ptr, qir_minor_version.unwrap_or(0));
}
unsafe {
qirlib::module::set_dynamic_qubit_management(
ptr,
dynamic_qubit_management.unwrap_or(false),
);
}
unsafe {
qirlib::module::set_dynamic_result_management(
ptr,
dynamic_result_management.unwrap_or(false),
);
}
Ok(Py::new(py, module)?.to_object(py))
}

/// The QIR major version this module is built for. None if unspecified.
#[pyfunction]
#[pyo3(text_signature = "(module)")]
pub(crate) fn qir_major_version(module: PyRef<Module>) -> Option<i32> {
unsafe { qirlib::module::qir_major_version(module.as_ptr()) }
}

/// The QIR minor version this module is built for. None if unspecified.
#[pyfunction]
#[pyo3(text_signature = "(module)")]
pub(crate) fn qir_minor_version(module: PyRef<Module>) -> Option<i32> {
unsafe { qirlib::module::qir_minor_version(module.as_ptr()) }
}

/// Whether this module supports dynamic qubit management. None if unspecified.
#[pyfunction]
#[pyo3(text_signature = "(module)")]
pub(crate) fn dynamic_qubit_management(module: PyRef<Module>) -> Option<bool> {
unsafe { qirlib::module::dynamic_qubit_management(module.as_ptr()) }
}

/// Whether this module supports dynamic result management. None if unspecified.
#[pyfunction]
#[pyo3(text_signature = "(module)")]
pub(crate) fn dynamic_result_management(module: PyRef<Module>) -> Option<bool> {
unsafe { qirlib::module::dynamic_result_management(module.as_ptr()) }
}

/// Creates a global null-terminated byte string constant in a module.
///
/// :param Module module: The parent module.
Expand Down
36 changes: 36 additions & 0 deletions pyqir/tests/test_module_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,39 @@ def test_add_value_flag_raises_with_wrong_ownership() -> None:
mod = pyqir.Module(pyqir.Context(), "")
with pytest.raises(ValueError):
mod.add_flag(ModuleFlagBehavior.ERROR, "", value)


def test_module_qir_major_version() -> None:
assert pyqir.qir_major_version(pyqir.Module(pyqir.Context(), "")) is None
assert pyqir.qir_major_version(pyqir.qir_module(pyqir.Context(), "")) is 1
mod = pyqir.qir_module(pyqir.Context(), "", 42)
assert pyqir.qir_major_version(mod) == 42


def test_module_qir_minor_version() -> None:
assert pyqir.qir_minor_version(pyqir.Module(pyqir.Context(), "")) is None
assert pyqir.qir_minor_version(pyqir.qir_module(pyqir.Context(), "")) is 0
mod = pyqir.qir_module(pyqir.Context(), "", 1, 42)
assert pyqir.qir_minor_version(mod) == 42


def test_module_dynamic_qubit_management() -> None:
assert pyqir.dynamic_qubit_management(pyqir.Module(pyqir.Context(), "")) is None
assert (
pyqir.dynamic_qubit_management(pyqir.qir_module(pyqir.Context(), "")) is False
)
mod = pyqir.qir_module(pyqir.Context(), "", dynamic_qubit_management=True)
assert pyqir.dynamic_qubit_management(mod) == True
mod = pyqir.qir_module(pyqir.Context(), "", dynamic_qubit_management=False)
assert pyqir.dynamic_qubit_management(mod) == False


def test_module_dynamic_result_management() -> None:
assert pyqir.dynamic_result_management(pyqir.Module(pyqir.Context(), "")) is None
assert (
pyqir.dynamic_result_management(pyqir.qir_module(pyqir.Context(), "")) is False
)
mod = pyqir.qir_module(pyqir.Context(), "", dynamic_result_management=True)
assert pyqir.dynamic_result_management(mod) == True
mod = pyqir.qir_module(pyqir.Context(), "", dynamic_result_management=False)
assert pyqir.dynamic_result_management(mod) == False
Loading

0 comments on commit b49f671

Please sign in to comment.