diff --git a/examples/dynamic_allocation.py b/examples/dynamic_allocation.py index ddd6352a..96e06eaf 100644 --- a/examples/dynamic_allocation.py +++ b/examples/dynamic_allocation.py @@ -13,36 +13,15 @@ ) context = Context() -mod = Module(context, "dynamic_allocation") -builder = Builder(context) - -# Define module flags -i1 = pyqir.IntType(context, 1) -i32 = pyqir.IntType(context, 32) - -mod.add_flag( - ModuleFlagBehavior.ERROR, - "qir_major_version", - pyqir.const(i32, 1), -) - -mod.add_flag( - ModuleFlagBehavior.MAX, - "qir_minor_version", - pyqir.const(i32, 0), -) - -mod.add_flag( - ModuleFlagBehavior.ERROR, - "dynamic_qubit_management", - pyqir.const(i1, True), -) - -mod.add_flag( - ModuleFlagBehavior.ERROR, - "dynamic_result_management", - pyqir.const(i1, True), +mod = pyqir.qir_module( + context, + "dynamic_allocation", + qir_major_version=1, + qir_minor_version=0, + dynamic_qubit_management=True, + dynamic_result_management=True, ) +builder = Builder(context) # define external calls and type definitions qubit_type = pyqir.qubit_type(context) diff --git a/pyqir/pyqir/__init__.py b/pyqir/pyqir/__init__.py index 0762825b..2f555c07 100644 --- a/pyqir/pyqir/__init__.py +++ b/pyqir/pyqir/__init__.py @@ -35,6 +35,8 @@ Type, Value, const, + dynamic_qubit_management, + dynamic_result_management, entry_point, extract_byte_string, global_byte_string, @@ -45,6 +47,9 @@ qubit, qubit_id, qubit_type, + qir_major_version, + qir_minor_version, + qir_module, required_num_qubits, required_num_results, result, @@ -90,6 +95,8 @@ "Type", "Value", "const", + "dynamic_qubit_management", + "dynamic_result_management", "entry_point", "extract_byte_string", "global_byte_string", @@ -100,6 +107,9 @@ "qubit_id", "qubit_type", "qubit", + "qir_major_version", + "qir_minor_version", + "qir_module", "required_num_qubits", "required_num_results", "result_id", diff --git a/pyqir/pyqir/_native.pyi b/pyqir/pyqir/_native.pyi index edd49de6..0c84966a 100644 --- a/pyqir/pyqir/_native.pyi +++ b/pyqir/pyqir/_native.pyi @@ -778,6 +778,44 @@ def entry_point( """ ... +def qir_major_version(module: Module) -> Optional[int]: + """The QIR major version this module is built for. None if unspecified.""" + ... + +def qir_minor_version(module: Module) -> Optional[int]: + """The QIR minor version this module is built for. None if unspecified.""" + ... + +def dynamic_qubit_management(module: Module) -> Optional[bool]: + """Whether this module supports dynamic qubit management. None if unspecified.""" + ... + +def dynamic_result_management(module: Module) -> Optional[bool]: + """Whether this module supports dynamic result management. None if unspecified.""" + ... + +def qir_module( + context: Context, + name: str, + qir_major_version: int = 1, + qir_minor_version: int = 0, + dynamic_qubit_management: bool = False, + dynamic_result_management: bool = False, +) -> Module: + """ + Creates a module with required QIR module flag metadata + + :param Context context: The parent context. + :param str name: The module name. + :param int qir_major_version: The QIR major version this module is built for. Default 1. + :param int qir_minor_version: The QIR minor version this module is built for. Default 0. + :param bool dynamic_qubit_management: Whether this module supports dynamic qubit management. Default False. + :param bool dynamic_result_management: Whether this module supports dynamic result management. Default False. + :returns: A module with the QIR module flags initialized + :rtype: Module + """ + ... + def extract_byte_string(value: Value) -> Optional[bytes]: """ If the value is a pointer to a constant byte string, extracts it. diff --git a/pyqir/pyqir/_simple.py b/pyqir/pyqir/_simple.py index e928cb4b..256754c9 100644 --- a/pyqir/pyqir/_simple.py +++ b/pyqir/pyqir/_simple.py @@ -45,7 +45,14 @@ def __init__( if context is None: context = Context() - self._module = Module(context, name) + self._module = pyqir.qir_module( + context, + name, + qir_major_version=1, + qir_minor_version=0, + dynamic_qubit_management=False, + dynamic_result_management=False, + ) self._builder = Builder(context) self._num_qubits = num_qubits self._num_results = num_results @@ -53,33 +60,6 @@ def __init__( entry_point = pyqir.entry_point(self._module, "main", num_qubits, num_results) self._builder.insert_at_end(BasicBlock(context, "entry", entry_point)) - i1 = pyqir.IntType(context, 1) - i32 = pyqir.IntType(context, 32) - - self._module.add_flag( - ModuleFlagBehavior.ERROR, - "qir_major_version", - pyqir.const(i32, 1), - ) - - self._module.add_flag( - ModuleFlagBehavior.MAX, - "qir_minor_version", - pyqir.const(i32, 0), - ) - - self._module.add_flag( - ModuleFlagBehavior.ERROR, - "dynamic_qubit_management", - pyqir.const(i1, False), - ) - - self._module.add_flag( - ModuleFlagBehavior.ERROR, - "dynamic_result_management", - pyqir.const(i1, False), - ) - @property def context(self) -> Context: """The LLVM context.""" diff --git a/pyqir/src/python.rs b/pyqir/src/python.rs index 427ccc2e..a594d69a 100644 --- a/pyqir/src/python.rs +++ b/pyqir/src/python.rs @@ -19,10 +19,11 @@ use crate::{ PointerType, StructType, Type, }, values::{ - entry_point, extract_byte_string, global_byte_string, is_entry_point, is_interop_friendly, - qubit, qubit_id, r#const, required_num_qubits, required_num_results, result, result_id, - Attribute, AttributeList, AttributeSet, BasicBlock, Constant, FloatConstant, Function, - IntConstant, Value, + dynamic_qubit_management, dynamic_result_management, entry_point, extract_byte_string, + global_byte_string, is_entry_point, is_interop_friendly, qir_major_version, + qir_minor_version, qir_module, qubit, qubit_id, r#const, required_num_qubits, + required_num_results, result, result_id, Attribute, AttributeList, AttributeSet, + BasicBlock, Constant, FloatConstant, Function, IntConstant, Value, }, }; use pyo3::prelude::*; @@ -61,6 +62,8 @@ fn _native(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_function(wrap_pyfunction!(dynamic_qubit_management, m)?)?; + m.add_function(wrap_pyfunction!(dynamic_result_management, m)?)?; m.add_function(wrap_pyfunction!(entry_point, m)?)?; m.add_function(wrap_pyfunction!(extract_byte_string, m)?)?; m.add_function(wrap_pyfunction!(global_byte_string, m)?)?; @@ -68,6 +71,9 @@ fn _native(_py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(is_interop_friendly, m)?)?; m.add_function(wrap_pyfunction!(is_qubit_type, m)?)?; m.add_function(wrap_pyfunction!(is_result_type, m)?)?; + m.add_function(wrap_pyfunction!(qir_major_version, m)?)?; + m.add_function(wrap_pyfunction!(qir_minor_version, m)?)?; + m.add_function(wrap_pyfunction!(qir_module, m)?)?; m.add_function(wrap_pyfunction!(qubit_id, m)?)?; m.add_function(wrap_pyfunction!(qubit_type, m)?)?; m.add_function(wrap_pyfunction!(qubit, m)?)?; diff --git a/pyqir/src/values.rs b/pyqir/src/values.rs index ed668f6e..ae345046 100644 --- a/pyqir/src/values.rs +++ b/pyqir/src/values.rs @@ -761,6 +761,79 @@ pub(crate) fn required_num_results(function: PyRef) -> Option { unsafe { values::required_num_results(function.into_super().into_super().as_ptr()) } } +/// Creates a module with required QIR module flag metadata +/// +/// :param Context context: The parent context. +/// :param str name: The module name. +/// :param int qir_major_version: The QIR major version this module is built for. Default 1. +/// :param int qir_minor_version: The QIR minor version this module is built for. Default 0. +/// :param bool dynamic_qubit_management: Whether this module supports dynamic qubit management. Default False. +/// :param bool dynamic_result_management: Whether this module supports dynamic result management. Default False. +/// :rtype: Module +#[pyfunction] +#[pyo3( + text_signature = "(context, name, qir_major_version, qir_minor_version, dynamic_qubit_management, dynamic_result_management)" +)] +pub(crate) fn qir_module( + py: Python, + context: Py, + name: &str, + qir_major_version: Option, + qir_minor_version: Option, + dynamic_qubit_management: Option, + dynamic_result_management: Option, +) -> PyResult { + let module = crate::module::Module::new(py, context, name); + let ptr = module.as_ptr(); + unsafe { + qirlib::module::set_qir_major_version(ptr, qir_major_version.unwrap_or(1)); + } + unsafe { + qirlib::module::set_qir_minor_version(ptr, qir_minor_version.unwrap_or(0)); + } + unsafe { + qirlib::module::set_dynamic_qubit_management( + ptr, + dynamic_qubit_management.unwrap_or(false), + ); + } + unsafe { + qirlib::module::set_dynamic_result_management( + ptr, + dynamic_result_management.unwrap_or(false), + ); + } + Ok(Py::new(py, module)?.to_object(py)) +} + +/// The QIR major version this module is built for. None if unspecified. +#[pyfunction] +#[pyo3(text_signature = "(module)")] +pub(crate) fn qir_major_version(module: PyRef) -> Option { + unsafe { qirlib::module::qir_major_version(module.as_ptr()) } +} + +/// The QIR minor version this module is built for. None if unspecified. +#[pyfunction] +#[pyo3(text_signature = "(module)")] +pub(crate) fn qir_minor_version(module: PyRef) -> Option { + unsafe { qirlib::module::qir_minor_version(module.as_ptr()) } +} + +/// Whether this module supports dynamic qubit management. None if unspecified. +#[pyfunction] +#[pyo3(text_signature = "(module)")] +pub(crate) fn dynamic_qubit_management(module: PyRef) -> Option { + unsafe { qirlib::module::dynamic_qubit_management(module.as_ptr()) } +} + +/// Whether this module supports dynamic result management. None if unspecified. +#[pyfunction] +#[pyo3(text_signature = "(module)")] +pub(crate) fn dynamic_result_management(module: PyRef) -> Option { + unsafe { qirlib::module::dynamic_result_management(module.as_ptr()) } +} + /// Creates a global null-terminated byte string constant in a module. /// /// :param Module module: The parent module. diff --git a/pyqir/tests/test_module_attributes.py b/pyqir/tests/test_module_attributes.py index 1225163e..75a03206 100644 --- a/pyqir/tests/test_module_attributes.py +++ b/pyqir/tests/test_module_attributes.py @@ -89,3 +89,39 @@ def test_add_value_flag_raises_with_wrong_ownership() -> None: mod = pyqir.Module(pyqir.Context(), "") with pytest.raises(ValueError): mod.add_flag(ModuleFlagBehavior.ERROR, "", value) + + +def test_module_qir_major_version() -> None: + assert pyqir.qir_major_version(pyqir.Module(pyqir.Context(), "")) is None + assert pyqir.qir_major_version(pyqir.qir_module(pyqir.Context(), "")) is 1 + mod = pyqir.qir_module(pyqir.Context(), "", 42) + assert pyqir.qir_major_version(mod) == 42 + + +def test_module_qir_minor_version() -> None: + assert pyqir.qir_minor_version(pyqir.Module(pyqir.Context(), "")) is None + assert pyqir.qir_minor_version(pyqir.qir_module(pyqir.Context(), "")) is 0 + mod = pyqir.qir_module(pyqir.Context(), "", 1, 42) + assert pyqir.qir_minor_version(mod) == 42 + + +def test_module_dynamic_qubit_management() -> None: + assert pyqir.dynamic_qubit_management(pyqir.Module(pyqir.Context(), "")) is None + assert ( + pyqir.dynamic_qubit_management(pyqir.qir_module(pyqir.Context(), "")) is False + ) + mod = pyqir.qir_module(pyqir.Context(), "", dynamic_qubit_management=True) + assert pyqir.dynamic_qubit_management(mod) == True + mod = pyqir.qir_module(pyqir.Context(), "", dynamic_qubit_management=False) + assert pyqir.dynamic_qubit_management(mod) == False + + +def test_module_dynamic_result_management() -> None: + assert pyqir.dynamic_result_management(pyqir.Module(pyqir.Context(), "")) is None + assert ( + pyqir.dynamic_result_management(pyqir.qir_module(pyqir.Context(), "")) is False + ) + mod = pyqir.qir_module(pyqir.Context(), "", dynamic_result_management=True) + assert pyqir.dynamic_result_management(mod) == True + mod = pyqir.qir_module(pyqir.Context(), "", dynamic_result_management=False) + assert pyqir.dynamic_result_management(mod) == False diff --git a/qirlib/src/module.rs b/qirlib/src/module.rs index 1da69990..10d20570 100644 --- a/qirlib/src/module.rs +++ b/qirlib/src/module.rs @@ -1,9 +1,18 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -use llvm_sys::prelude::{LLVMMetadataRef, LLVMModuleRef}; +use llvm_sys::{ + core::{ + LLVMConstInt, LLVMConstIntGetZExtValue, LLVMGetModuleContext, LLVMGetModuleFlag, + LLVMInt1TypeInContext, LLVMInt32TypeInContext, LLVMMetadataAsValue, LLVMValueAsMetadata, + }, + prelude::{LLVMMetadataRef, LLVMModuleRef}, +}; -use crate::llvm_wrapper::{LLVMRustAddModuleFlag, LLVMRustModFlagBehavior}; +use crate::{ + llvm_wrapper::{LLVMRustAddModuleFlag, LLVMRustModFlagBehavior}, + metadata::extract_constant, +}; pub enum FlagBehavior { Error, @@ -43,6 +52,82 @@ impl From for LLVMRustModFlagBehavior { } } +pub unsafe fn qir_major_version(module: LLVMModuleRef) -> Option { + i32::try_from(get_u64_flag(module, "qir_major_version")?).ok() +} + +pub unsafe fn set_qir_major_version(module: LLVMModuleRef, value: i32) { + let context = LLVMGetModuleContext(module); + let i32ty = LLVMInt32TypeInContext(context); + let const_value = LLVMConstInt(i32ty, value.try_into().unwrap(), 0); + let md = LLVMValueAsMetadata(const_value); + add_flag(module, FlagBehavior::Error, "qir_major_version", md); +} + +pub unsafe fn qir_minor_version(module: LLVMModuleRef) -> Option { + i32::try_from(get_u64_flag(module, "qir_minor_version")?).ok() +} + +pub unsafe fn set_qir_minor_version(module: LLVMModuleRef, value: i32) { + let context = LLVMGetModuleContext(module); + let i32ty = LLVMInt32TypeInContext(context); + let const_value = LLVMConstInt(i32ty, value.try_into().unwrap(), 0); + let md = LLVMValueAsMetadata(const_value); + add_flag(module, FlagBehavior::Max, "qir_minor_version", md); +} + +pub unsafe fn dynamic_qubit_management(module: LLVMModuleRef) -> Option { + get_i1_flag(module, "dynamic_qubit_management") +} + +pub unsafe fn set_dynamic_qubit_management(module: LLVMModuleRef, value: bool) { + let context = LLVMGetModuleContext(module); + let i1ty = LLVMInt1TypeInContext(context); + let const_value = LLVMConstInt(i1ty, u64::from(value), 0); + let md = LLVMValueAsMetadata(const_value); + add_flag(module, FlagBehavior::Error, "dynamic_qubit_management", md); +} + +pub unsafe fn dynamic_result_management(module: LLVMModuleRef) -> Option { + get_i1_flag(module, "dynamic_result_management") +} + +pub unsafe fn set_dynamic_result_management(module: LLVMModuleRef, value: bool) { + let context = LLVMGetModuleContext(module); + let i1ty = LLVMInt1TypeInContext(context); + let const_value = LLVMConstInt(i1ty, u64::from(value), 0); + let md = LLVMValueAsMetadata(const_value); + add_flag(module, FlagBehavior::Error, "dynamic_result_management", md); +} + +unsafe fn get_u64_flag(module: LLVMModuleRef, id: &str) -> Option { + if let Some(flag) = get_flag(module, id) { + if let Some(constant) = + extract_constant(LLVMMetadataAsValue(LLVMGetModuleContext(module), flag)) + { + let value = LLVMConstIntGetZExtValue(constant); + return Some(value); + } + } + None +} + +unsafe fn get_i1_flag(module: LLVMModuleRef, id: &str) -> Option { + if let Some(value) = get_u64_flag(module, id) { + return Some(value != 0); + } + None +} + +pub unsafe fn get_flag(module: LLVMModuleRef, id: &str) -> Option { + let flag = LLVMGetModuleFlag(module, id.as_ptr().cast(), id.len()); + + if flag.is_null() { + return None; + } + Some(flag) +} + pub unsafe fn add_flag( module: LLVMModuleRef, behavior: FlagBehavior,