From feeef97c4e78b52929b78a70412730710565fc7f Mon Sep 17 00:00:00 2001 From: Ian Davis Date: Mon, 16 Sep 2024 13:10:38 -0500 Subject: [PATCH] Add link function to Module (#293) --- pyqir/pyqir/_native.pyi | 9 +++ pyqir/src/module.rs | 35 +++++++++- pyqir/tests/5_bit_random_number.ll | 44 ++++++++++++ pyqir/tests/combined_module.ll | 56 ++++++++++++++++ pyqir/tests/profile_v1.0_compat.ll | 34 ++++++++++ pyqir/tests/profile_v1.1_compat.ll | 34 ++++++++++ pyqir/tests/profile_v2.0_compat.ll | 34 ++++++++++ pyqir/tests/random_bit.ll | 34 ++++++++++ pyqir/tests/test_module_linking.py | 104 +++++++++++++++++++++++++++++ qirlib/src/context.rs | 30 +++++++++ qirlib/src/lib.rs | 2 + 11 files changed, 415 insertions(+), 1 deletion(-) create mode 100644 pyqir/tests/5_bit_random_number.ll create mode 100644 pyqir/tests/combined_module.ll create mode 100644 pyqir/tests/profile_v1.0_compat.ll create mode 100644 pyqir/tests/profile_v1.1_compat.ll create mode 100644 pyqir/tests/profile_v2.0_compat.ll create mode 100644 pyqir/tests/random_bit.ll create mode 100644 pyqir/tests/test_module_linking.py create mode 100644 qirlib/src/context.rs diff --git a/pyqir/pyqir/_native.pyi b/pyqir/pyqir/_native.pyi index 0723dd50..4b151057 100644 --- a/pyqir/pyqir/_native.pyi +++ b/pyqir/pyqir/_native.pyi @@ -681,6 +681,15 @@ class Module: """Converts this module into an LLVM IR string.""" ... + def link(self, other: Module) -> None: + """ + Link the supplied module into the current module. + Destroys the supplied module. + + :raises: An error if linking failed. + """ + ... + class ModuleFlagBehavior(Enum): """Module flag behavior choices""" diff --git a/pyqir/src/module.rs b/pyqir/src/module.rs index 8ee46652..1d6d2c53 100644 --- a/pyqir/src/module.rs +++ b/pyqir/src/module.rs @@ -9,6 +9,7 @@ use crate::{ metadata::Metadata, values::{Constant, Owner, Value}, }; +use core::mem::forget; use core::slice; #[allow(clippy::wildcard_imports, deprecated)] use llvm_sys::{ @@ -17,10 +18,11 @@ use llvm_sys::{ bit_writer::LLVMWriteBitcodeToMemoryBuffer, core::*, ir_reader::LLVMParseIRInContext, + linker::LLVMLinkModules2, LLVMLinkage, LLVMModule, }; use pyo3::{exceptions::PyValueError, prelude::*, pyclass::CompareOp, types::PyBytes}; -use qirlib::module::FlagBehavior; +use qirlib::{context::set_diagnostic_handler, module::FlagBehavior}; use std::{ collections::hash_map::DefaultHasher, ffi::CString, @@ -263,6 +265,37 @@ impl Module { .to_string() } } + + /// Link the supplied module into the current module. + /// Destroys the supplied module. + /// + /// :raises: An error if linking failed. + pub fn link(&self, other: Py, py: Python) -> PyResult<()> { + let context = self.context.borrow(py).as_ptr(); + if context != other.borrow(py).context.borrow(py).as_ptr() { + return Err(PyValueError::new_err( + "Cannot link modules from different contexts. Modules are untouched.".to_string(), + )); + } + unsafe { + let mut c_char_output: *mut ::core::ffi::c_char = ptr::null_mut(); + let output = ::core::ptr::from_mut::<*mut ::core::ffi::c_char>(&mut c_char_output) + .cast::<*mut ::core::ffi::c_void>() + .cast::<::core::ffi::c_void>(); + + set_diagnostic_handler(context, output); + let result = LLVMLinkModules2(self.module.as_ptr(), other.borrow(py).module.as_ptr()); + // `forget` the other module. LLVM has destroyed it + // and we'll get a segfault if we drop it. + forget(other); + if result == 0 { + Ok(()) + } else { + let error = Message::from_raw(c_char_output); + return Err(PyValueError::new_err(error.to_str().unwrap().to_string())); + } + } + } } impl Deref for Module { diff --git a/pyqir/tests/5_bit_random_number.ll b/pyqir/tests/5_bit_random_number.ll new file mode 100644 index 00000000..9c0c7b8f --- /dev/null +++ b/pyqir/tests/5_bit_random_number.ll @@ -0,0 +1,44 @@ +; ModuleID = '5_bit_random_number' +%Result = type opaque +%Qubit = type opaque + +define void @five_bit_random_number() #0 { +block_0: + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 2 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 3 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 4 to %Qubit*)) + call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 1 to %Result*)) + call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 2 to %Qubit*), %Result* inttoptr (i64 2 to %Result*)) + call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 3 to %Qubit*), %Result* inttoptr (i64 3 to %Result*)) + call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 4 to %Qubit*), %Result* inttoptr (i64 4 to %Result*)) + call void @__quantum__rt__array_record_output(i64 5, i8* null) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* null) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 1 to %Result*), i8* null) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 2 to %Result*), i8* null) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 3 to %Result*), i8* null) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 4 to %Result*), i8* null) + ret void +} + +declare void @__quantum__qis__h__body(%Qubit*) + +declare void @__quantum__rt__array_record_output(i64, i8*) + +declare void @__quantum__rt__result_record_output(%Result*, i8*) + +declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1 + +attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="base_profile" "required_num_qubits"="5" "required_num_results"="5" } +attributes #1 = { "irreversible" } + +; module flags + +!llvm.module.flags = !{!0, !1, !2, !3} + +!0 = !{i32 1, !"qir_major_version", i32 1} +!1 = !{i32 7, !"qir_minor_version", i32 0} +!2 = !{i32 1, !"dynamic_qubit_management", i1 false} +!3 = !{i32 1, !"dynamic_result_management", i1 false} \ No newline at end of file diff --git a/pyqir/tests/combined_module.ll b/pyqir/tests/combined_module.ll new file mode 100644 index 00000000..ebadbb78 --- /dev/null +++ b/pyqir/tests/combined_module.ll @@ -0,0 +1,56 @@ + +%Qubit = type opaque +%Result = type opaque + +define void @random_bit() #0 { +block_0: + call void @__quantum__qis__h__body(%Qubit* null) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__cz__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* null) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* null) + call void @__quantum__rt__result_record_output(%Result* null, i8* null) + ret void +} + +declare void @__quantum__qis__h__body(%Qubit*) + +declare void @__quantum__qis__cz__body(%Qubit*, %Qubit*) + +declare void @__quantum__rt__result_record_output(%Result*, i8*) + +declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1 + +define void @five_bit_random_number() #2 { +block_0: + call void @__quantum__qis__h__body(%Qubit* null) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 2 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 3 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 4 to %Qubit*)) + call void @__quantum__qis__m__body(%Qubit* null, %Result* null) + call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 1 to %Result*)) + call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 2 to %Qubit*), %Result* inttoptr (i64 2 to %Result*)) + call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 3 to %Qubit*), %Result* inttoptr (i64 3 to %Result*)) + call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 4 to %Qubit*), %Result* inttoptr (i64 4 to %Result*)) + call void @__quantum__rt__array_record_output(i64 5, i8* null) + call void @__quantum__rt__result_record_output(%Result* null, i8* null) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 1 to %Result*), i8* null) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 2 to %Result*), i8* null) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 3 to %Result*), i8* null) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 4 to %Result*), i8* null) + ret void +} + +declare void @__quantum__rt__array_record_output(i64, i8*) + +attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="base_profile" "required_num_qubits"="2" "required_num_results"="1" } +attributes #1 = { "irreversible" } +attributes #2 = { "entry_point" "output_labeling_schema" "qir_profiles"="base_profile" "required_num_qubits"="5" "required_num_results"="5" } + +!llvm.module.flags = !{!0, !1, !2, !3} + +!0 = !{i32 1, !"qir_major_version", i32 1} +!1 = !{i32 7, !"qir_minor_version", i32 0} +!2 = !{i32 1, !"dynamic_qubit_management", i1 false} +!3 = !{i32 1, !"dynamic_result_management", i1 false} diff --git a/pyqir/tests/profile_v1.0_compat.ll b/pyqir/tests/profile_v1.0_compat.ll new file mode 100644 index 00000000..9a1ef1c2 --- /dev/null +++ b/pyqir/tests/profile_v1.0_compat.ll @@ -0,0 +1,34 @@ +; ModuleID = 'OneDotZero' +%Result = type opaque +%Qubit = type opaque + +define void @OneDotZero() #0 { +block_0: + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__cz__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* null) + ret void +} + +declare void @__quantum__qis__h__body(%Qubit*) + +declare void @__quantum__qis__cz__body(%Qubit*, %Qubit*) + +declare void @__quantum__rt__result_record_output(%Result*, i8*) + +declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1 + +attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="base_profile" "required_num_qubits"="2" "required_num_results"="1" } +attributes #1 = { "irreversible" } + +; module flags + +!llvm.module.flags = !{!0, !1, !2, !3} + +!0 = !{i32 1, !"qir_major_version", i32 1} +!1 = !{i32 7, !"qir_minor_version", i32 0} +!2 = !{i32 1, !"dynamic_qubit_management", i1 false} +!3 = !{i32 1, !"dynamic_result_management", i1 false} \ No newline at end of file diff --git a/pyqir/tests/profile_v1.1_compat.ll b/pyqir/tests/profile_v1.1_compat.ll new file mode 100644 index 00000000..1bf726d1 --- /dev/null +++ b/pyqir/tests/profile_v1.1_compat.ll @@ -0,0 +1,34 @@ +; ModuleID = 'OneDotOne' +%Result = type opaque +%Qubit = type opaque + +define void @OneDotOne() #0 { +block_0: + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__cz__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* null) + ret void +} + +declare void @__quantum__qis__h__body(%Qubit*) + +declare void @__quantum__qis__cz__body(%Qubit*, %Qubit*) + +declare void @__quantum__rt__result_record_output(%Result*, i8*) + +declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1 + +attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="base_profile" "required_num_qubits"="2" "required_num_results"="1" } +attributes #1 = { "irreversible" } + +; module flags + +!llvm.module.flags = !{!0, !1, !2, !3} + +!0 = !{i32 1, !"qir_major_version", i32 1} +!1 = !{i32 7, !"qir_minor_version", i32 1} +!2 = !{i32 1, !"dynamic_qubit_management", i1 false} +!3 = !{i32 1, !"dynamic_result_management", i1 false} \ No newline at end of file diff --git a/pyqir/tests/profile_v2.0_compat.ll b/pyqir/tests/profile_v2.0_compat.ll new file mode 100644 index 00000000..c51531ca --- /dev/null +++ b/pyqir/tests/profile_v2.0_compat.ll @@ -0,0 +1,34 @@ +; ModuleID = 'TwoDotZero' +%Result = type opaque +%Qubit = type opaque + +define void @TwoDotZero() #0 { +block_0: + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__cz__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* null) + ret void +} + +declare void @__quantum__qis__h__body(%Qubit*) + +declare void @__quantum__qis__cz__body(%Qubit*, %Qubit*) + +declare void @__quantum__rt__result_record_output(%Result*, i8*) + +declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1 + +attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="base_profile" "required_num_qubits"="2" "required_num_results"="1" } +attributes #1 = { "irreversible" } + +; module flags + +!llvm.module.flags = !{!0, !1, !2, !3} + +!0 = !{i32 1, !"qir_major_version", i32 2} +!1 = !{i32 7, !"qir_minor_version", i32 0} +!2 = !{i32 1, !"dynamic_qubit_management", i1 false} +!3 = !{i32 1, !"dynamic_result_management", i1 false} \ No newline at end of file diff --git a/pyqir/tests/random_bit.ll b/pyqir/tests/random_bit.ll new file mode 100644 index 00000000..7c1cf384 --- /dev/null +++ b/pyqir/tests/random_bit.ll @@ -0,0 +1,34 @@ +; ModuleID = 'random_bit' +%Result = type opaque +%Qubit = type opaque + +define void @random_bit() #0 { +block_0: + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__cz__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* null) + ret void +} + +declare void @__quantum__qis__h__body(%Qubit*) + +declare void @__quantum__qis__cz__body(%Qubit*, %Qubit*) + +declare void @__quantum__rt__result_record_output(%Result*, i8*) + +declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1 + +attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="base_profile" "required_num_qubits"="2" "required_num_results"="1" } +attributes #1 = { "irreversible" } + +; module flags + +!llvm.module.flags = !{!0, !1, !2, !3} + +!0 = !{i32 1, !"qir_major_version", i32 1} +!1 = !{i32 7, !"qir_minor_version", i32 0} +!2 = !{i32 1, !"dynamic_qubit_management", i1 false} +!3 = !{i32 1, !"dynamic_result_management", i1 false} \ No newline at end of file diff --git a/pyqir/tests/test_module_linking.py b/pyqir/tests/test_module_linking.py new file mode 100644 index 00000000..10c5d94c --- /dev/null +++ b/pyqir/tests/test_module_linking.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from pathlib import Path + +import pyqir +import pytest + +current_file_path = Path(__file__) +# Get the directory of the current file +current_dir = current_file_path.parent + +from pyqir import ( + Context, + Module, +) + + +def read_file(file_name: str) -> str: + return Path(current_dir / file_name).read_text(encoding="utf-8") + + +def get_int_flag_value(module: Module, flag_name: str) -> int: + flag = module.get_flag(flag_name) + assert flag is not None + assert isinstance(flag, pyqir.ConstantAsMetadata) + assert isinstance(flag.value, pyqir.IntConstant) + return flag.value.value + + +def test_link_modules_with_same_context() -> None: + context = Context() + ir = read_file("random_bit.ll") + dest = Module.from_ir(context, ir) + ir = read_file("5_bit_random_number.ll") + src = Module.from_ir(context, ir) + dest.link(src) + assert dest.verify() is None + actual_ir = str(dest) + expected_ir = str(read_file("combined_module.ll")) + assert actual_ir == expected_ir + + +def test_link_modules_with_different_contexts() -> None: + ir = read_file("random_bit.ll") + dest = Module.from_ir(Context(), ir) + ir = read_file("5_bit_random_number.ll") + src = Module.from_ir(Context(), ir) + with pytest.raises(ValueError) as ex: + dest.link(src) + assert ( + str(ex.value) + == "Cannot link modules from different contexts. Modules are untouched." + ) + + +def test_link_module_with_src_minor_version_less() -> None: + context = Context() + ir = read_file("profile_v1.0_compat.ll") + dest = Module.from_ir(context, ir) + ir = read_file("profile_v1.1_compat.ll") + src = Module.from_ir(context, ir) + dest.link(src) + assert get_int_flag_value(dest, "qir_minor_version") == 1 + + +def test_link_module_with_src_minor_version_greater() -> None: + context = Context() + ir = read_file("profile_v1.1_compat.ll") + dest = Module.from_ir(context, ir) + ir = read_file("profile_v1.0_compat.ll") + src = Module.from_ir(context, ir) + dest.link(src) + assert get_int_flag_value(dest, "qir_minor_version") == 1 + + +def test_link_module_with_src_major_version_less() -> None: + context = Context() + ir = read_file("profile_v2.0_compat.ll") + dest = Module.from_ir(context, ir) + ir = read_file("profile_v1.0_compat.ll") + src = Module.from_ir(context, ir) + with pytest.raises(ValueError) as ex: + dest.link(src) + + assert ( + "linking module flags 'qir_major_version': IDs have conflicting values" + in str(ex) + ) + + +def test_link_module_with_src_major_version_greater() -> None: + context = Context() + ir = read_file("profile_v1.0_compat.ll") + dest = Module.from_ir(context, ir) + ir = read_file("profile_v2.0_compat.ll") + src = Module.from_ir(context, ir) + with pytest.raises(ValueError) as ex: + dest.link(src) + + assert ( + "linking module flags 'qir_major_version': IDs have conflicting values" + in str(ex) + ) diff --git a/qirlib/src/context.rs b/qirlib/src/context.rs new file mode 100644 index 00000000..96b915c6 --- /dev/null +++ b/qirlib/src/context.rs @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use llvm_sys::{ + core::{LLVMContextSetDiagnosticHandler, LLVMGetDiagInfoDescription, LLVMGetDiagInfoSeverity}, + prelude::{LLVMContextRef, LLVMDiagnosticInfoRef}, + LLVMDiagnosticSeverity, +}; + +pub unsafe fn set_diagnostic_handler( + context: LLVMContextRef, + output_ptr: *mut ::core::ffi::c_void, +) { + unsafe { LLVMContextSetDiagnosticHandler(context, Some(diagnostic_handler), output_ptr) }; +} + +pub(crate) extern "C" fn diagnostic_handler( + diagnostic_info: LLVMDiagnosticInfoRef, + output: *mut ::core::ffi::c_void, +) { + unsafe { + let severity = LLVMGetDiagInfoSeverity(diagnostic_info); + if severity == LLVMDiagnosticSeverity::LLVMDSError { + let c_char_output = output + .cast::<*mut ::core::ffi::c_void>() + .cast::<*mut ::core::ffi::c_char>(); + *c_char_output = LLVMGetDiagInfoDescription(diagnostic_info); + } + } +} diff --git a/qirlib/src/lib.rs b/qirlib/src/lib.rs index 5711e787..038670f9 100644 --- a/qirlib/src/lib.rs +++ b/qirlib/src/lib.rs @@ -20,6 +20,8 @@ extern crate llvm_sys_140 as llvm_sys; #[cfg(not(feature = "no-llvm-linking"))] pub mod builder; #[cfg(not(feature = "no-llvm-linking"))] +pub mod context; +#[cfg(not(feature = "no-llvm-linking"))] pub(crate) mod llvm_wrapper; #[cfg(not(feature = "no-llvm-linking"))] pub mod metadata;