diff --git a/eng/psakefile.ps1 b/eng/psakefile.ps1 index 62a67c06..5b410868 100644 --- a/eng/psakefile.ps1 +++ b/eng/psakefile.ps1 @@ -109,7 +109,7 @@ task check-environment { } Assert ((Test-InVirtualEnvironment) -eq $true) ($env_message -Join ' ') - exec { & $Python -m pip install pip~=23.1 } + exec { & $Python -m pip install pip~=23.3 } } task init -depends check-environment { diff --git a/eng/utils.ps1 b/eng/utils.ps1 index 7cd10775..baf26169 100644 --- a/eng/utils.ps1 +++ b/eng/utils.ps1 @@ -174,11 +174,21 @@ function Test-AllowedToDownloadLlvm { } function Test-InCondaEnvironment { - (Test-Path env:\CONDA_PREFIX) + $found = (Test-Path env:\CONDA_PREFIX) + if ($found) { + $condaPrefix = $env:CONDA_PREFIX + Write-BuildLog "Found conda environment: $condaPrefix" + } + $found } function Test-InVenvEnvironment { - (Test-Path env:\VIRTUAL_ENV) + $found = (Test-Path env:\VIRTUAL_ENV) + if ($found) { + $venv = $env:VIRTUAL_ENV + Write-BuildLog "Found venv environment: $venv" + } + $found } function Test-InVirtualEnvironment { @@ -301,5 +311,5 @@ function install-llvm { if ($clear_cache_var) { Remove-Item -Path Env:QIRLIB_CACHE_DIR } - } + } } diff --git a/pyqir/pyqir/__init__.py b/pyqir/pyqir/__init__.py index 22e3d471..1906128b 100644 --- a/pyqir/pyqir/__init__.py +++ b/pyqir/pyqir/__init__.py @@ -59,6 +59,7 @@ from pyqir._simple import SimpleModule from pyqir._entry_point import entry_point from pyqir._basicqis import BasicQisBuilder +from pyqir._constants import ATTR_FUNCTION_INDEX, ATTR_RETURN_INDEX __all__ = [ "ArrayType", @@ -117,4 +118,6 @@ "result_id", "result_type", "result", + "ATTR_FUNCTION_INDEX", + "ATTR_RETURN_INDEX", ] diff --git a/pyqir/pyqir/_constants.py b/pyqir/pyqir/_constants.py new file mode 100644 index 00000000..8fd17226 --- /dev/null +++ b/pyqir/pyqir/_constants.py @@ -0,0 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +ATTR_RETURN_INDEX = 0 +ATTR_FUNCTION_INDEX = 4294967295 # -1 u32 diff --git a/pyqir/pyqir/_native.pyi b/pyqir/pyqir/_native.pyi index 42be7609..f091e954 100644 --- a/pyqir/pyqir/_native.pyi +++ b/pyqir/pyqir/_native.pyi @@ -2,7 +2,16 @@ # Licensed under the MIT License. from enum import Enum -from typing import Callable, List, Optional, Sequence, Tuple, Union +from typing import ( + Callable, + Iterable, + Iterator, + List, + Optional, + Sequence, + Tuple, + Union, +) class ArrayType(Type): """An array type.""" @@ -19,6 +28,10 @@ class ArrayType(Type): class Attribute: """An attribute.""" + @property + def string_kind(self) -> str: + """The kind of this attribute as a string.""" + ... @property def string_value(self) -> Optional[str]: """The value of this attribute as a string, or `None` if this is not a string attribute.""" @@ -44,7 +57,13 @@ class AttributeList: """The attributes for the function itself.""" ... -class AttributeSet: +class AttributeIterator(Iterator[Attribute]): + """An iterator of attributes for a specific part of a function.""" + + def __iter__(self) -> Iterator[Attribute]: ... + def __next__(self) -> Attribute: ... + +class AttributeSet(Iterable[Attribute]): """A set of attributes for a specific part of a function.""" def __contains__(self, item: str) -> bool: @@ -63,6 +82,7 @@ class AttributeSet: :returns: The attribute. """ ... + def __iter__(self) -> Iterator[Attribute]: ... class BasicBlock(Value): """A basic block.""" @@ -1223,7 +1243,10 @@ def if_result( ... def add_string_attribute( - function: Function, kind: str, value: Optional[str] = None + function: Function, + kind: str, + value: Optional[str] = None, + index: Optional[int] = None, ) -> bool: """ Adds a string attribute to the given function. @@ -1231,5 +1254,6 @@ def add_string_attribute( :param function: The function. :param key: The attribute key. :param value: The attribute value. + :param index: The optional attribute index, defaults to the function index. """ ... diff --git a/pyqir/src/values.rs b/pyqir/src/values.rs index 39e6c3cd..8fcc05f3 100644 --- a/pyqir/src/values.rs +++ b/pyqir/src/values.rs @@ -23,7 +23,7 @@ use pyo3::{ types::{PyBytes, PyLong, PyString}, PyRef, }; -use qirlib::values; +use qirlib::values::{self, get_string_attribute_kind, get_string_attribute_value}; use std::{ borrow::Borrow, collections::hash_map::DefaultHasher, @@ -33,6 +33,7 @@ use std::{ ops::Deref, ptr::NonNull, slice, str, + vec::IntoIter, }; /// A value. @@ -522,21 +523,20 @@ pub(crate) struct Attribute(LLVMAttributeRef); #[pymethods] impl Attribute { + /// The id of this attribute as a string. + /// + /// :type: str + #[getter] + fn string_kind(&self) -> String { + unsafe { get_string_attribute_kind(self.0) } + } + /// The value of this attribute as a string, or `None` if this is not a string attribute. /// /// :type: typing.Optional[str] #[getter] - fn string_value(&self) -> Option<&str> { - unsafe { - if LLVMIsStringAttribute(self.0) == 0 { - None - } else { - let mut len = 0; - let value = LLVMGetStringAttributeValue(self.0, &mut len).cast(); - let value = slice::from_raw_parts(value, len.try_into().unwrap()); - Some(str::from_utf8(value).unwrap()) - } - } + fn string_value(&self) -> Option { + unsafe { get_string_attribute_value(self.0) } } } @@ -588,6 +588,24 @@ pub(crate) struct AttributeSet { index: LLVMAttributeIndex, } +/// An iterator of attributes for a specific part of a function. +#[pyclass] +struct AttributeIterator { + iter: IntoIter>, +} + +#[pymethods] +impl AttributeIterator { + fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + // Returning `None` from `__next__` indicates that that there are no further items. + // and maps to StopIteration + fn __next__(mut slf: PyRefMut<'_, Self>) -> Option> { + slf.iter.next() + } +} + #[pymethods] impl AttributeSet { /// Tests if an attribute is a member of the set. @@ -622,6 +640,23 @@ impl AttributeSet { Ok(Attribute(attr)) } } + + fn __iter__(slf: PyRef<'_, Self>) -> PyResult> { + let function = slf.function.borrow(slf.py()).into_super().into_super(); + + unsafe { + let attrs = qirlib::values::get_attributes(function.as_ptr(), slf.index); + let items = attrs + .into_iter() + .map(|a| Py::new(slf.py(), Attribute(a)).expect("msg")); + Py::new( + slf.py(), + AttributeIterator { + iter: items.collect::>>().into_iter(), + }, + ) + } + } } #[derive(FromPyObject)] @@ -863,12 +898,14 @@ pub(crate) fn extract_byte_string<'py>(py: Python<'py>, value: &Value) -> Option // :param function: The function. // :param kind: The attribute kind. // :param value: The attribute value. +// :param index: The optional attribute index, defaults to the function index. #[pyfunction] -#[pyo3(text_signature = "(function, key, value)")] +#[pyo3(text_signature = "(function, key, value, index)")] pub(crate) fn add_string_attribute<'py>( function: PyRef, key: &'py PyString, value: Option<&'py PyString>, + index: Option, ) { let function = function.into_super().into_super().as_ptr(); let key = key.to_string_lossy(); @@ -881,6 +918,7 @@ pub(crate) fn add_string_attribute<'py>( Some(ref x) => x.as_bytes(), None => &[], }, + index.unwrap_or(LLVMAttributeFunctionIndex), ); } } diff --git a/pyqir/tests/test_string_attributes.py b/pyqir/tests/test_string_attributes.py index a2f68926..c38a71b3 100644 --- a/pyqir/tests/test_string_attributes.py +++ b/pyqir/tests/test_string_attributes.py @@ -1,10 +1,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from typing import List import pyqir from pyqir import ( + Attribute, + ATTR_FUNCTION_INDEX, + ATTR_RETURN_INDEX, + Builder, IntType, - ModuleFlagBehavior, Module, Context, add_string_attribute, @@ -41,9 +45,9 @@ def test_round_trip_serialize_parse() -> None: function = Function(FunctionType(void, []), Linkage.EXTERNAL, "test_function", mod) add_string_attribute(function, "foo", "bar") # also test for non-value attributes - add_string_attribute(function, "entry_point", "") + add_string_attribute(function, "entry_point") # test behavior of empty attribute - add_string_attribute(function, "", "") + add_string_attribute(function, "") ir = str(mod) parsed_mod = Module.from_ir(Context(), ir, "test") assert str(parsed_mod) == str(mod) @@ -54,7 +58,7 @@ def test_duplicate_attr_key_replaces_previous() -> None: void = pyqir.Type.void(mod.context) function = Function(FunctionType(void, []), Linkage.EXTERNAL, "test_function", mod) add_string_attribute(function, "foo", "bar") - add_string_attribute(function, "foo", "") + add_string_attribute(function, "foo") ir = str(mod) # Tests that subsequently added attributes with the same key # replace previously added ones @@ -77,3 +81,84 @@ def test_attribute_alphabetical_sorting() -> None: # Tests that attributes are sorted alphabetically by key, # irrespective of their value assert 'attributes #0 = { "1" "A"="123" "a"="a" "b"="A" "c" }' in ir + + +def test_function_attributes_can_be_iterated_in_alphabetical_order() -> None: + mod = pyqir.Module(pyqir.Context(), "test") + void = pyqir.Type.void(mod.context) + function = Function(FunctionType(void, []), Linkage.EXTERNAL, "test_function", mod) + # add them out of order, they will be sorted automatically + add_string_attribute(function, "required_num_results", "1") + add_string_attribute(function, "entry_point") + add_string_attribute(function, "required_num_qubits", "2") + attrs: List[Attribute] = list(function.attributes.func) + assert len(attrs) == 3 + # Tests that attributes are sorted alphabetically by indexing into the list + assert attrs[0].string_kind == "entry_point" + assert attrs[0].string_value == "" + assert attrs[1].string_kind == "required_num_qubits" + assert attrs[1].string_value == "2" + assert attrs[2].string_kind == "required_num_results" + assert attrs[2].string_value == "1" + + +def test_parameter_attrs() -> None: + mod = pyqir.Module(pyqir.Context(), "test") + void = pyqir.Type.void(mod.context) + i8 = IntType(mod.context, 8) + function = Function( + FunctionType(void, [i8]), Linkage.EXTERNAL, "test_function", mod + ) + # add them out of order, they will be sorted automatically + add_string_attribute(function, "zeroext", "", 1) + add_string_attribute(function, "mycustom", "myvalue", 1) + + # params have their own AttributeSet + attrs = list(function.attributes.param(0)) + + attr = attrs[0] + assert attr.string_kind == "mycustom" + assert attr.string_value == "myvalue" + + attr = attrs[1] + assert attr.string_kind == "zeroext" + assert attr.string_value == "" + + +def test_return_attrs_can_be_added_and_read() -> None: + mod = pyqir.Module(pyqir.Context(), "test") + void = pyqir.Type.void(mod.context) + i8 = IntType(mod.context, 8) + function = Function( + FunctionType(void, [i8]), Linkage.EXTERNAL, "test_function", mod + ) + builder = Builder(mod.context) + builder.ret(None) + + add_string_attribute(function, "mycustom", "myvalue", ATTR_RETURN_INDEX) + + # params have their own AttributeSet + attrs = list(function.attributes.ret) + + attr = attrs[0] + assert attr.string_kind == "mycustom" + assert attr.string_value == "myvalue" + + +def test_explicit_function_index_attrs_can_be_added_and_read() -> None: + mod = pyqir.Module(pyqir.Context(), "test") + void = pyqir.Type.void(mod.context) + i8 = IntType(mod.context, 8) + function = Function( + FunctionType(void, [i8]), Linkage.EXTERNAL, "test_function", mod + ) + builder = Builder(mod.context) + builder.ret(None) + + add_string_attribute(function, "mycustom", "myvalue", ATTR_FUNCTION_INDEX) + + attrs = list(function.attributes.func) + + attr = attrs[0] + assert attr.string_kind == "mycustom" + assert attr.string_value == "myvalue" diff --git a/qirlib/src/values.rs b/qirlib/src/values.rs index 36be3bb8..561f7f35 100644 --- a/qirlib/src/values.rs +++ b/qirlib/src/values.rs @@ -9,7 +9,13 @@ use llvm_sys::{ core::*, prelude::*, LLVMAttributeFunctionIndex, LLVMAttributeIndex, LLVMLinkage, LLVMOpaqueAttributeRef, LLVMOpcode, LLVMTypeKind, LLVMValueKind, }; -use std::{convert::TryFrom, ffi::CStr, ptr::NonNull, str}; +use std::{ + convert::TryFrom, + ffi::CStr, + mem::{ManuallyDrop, MaybeUninit}, + ptr::NonNull, + str, +}; pub unsafe fn qubit(context: LLVMContextRef, id: u64) -> LLVMValueRef { let i64 = LLVMInt64TypeInContext(context); @@ -51,24 +57,32 @@ pub unsafe fn entry_point( let ty = LLVMFunctionType(void, [].as_mut_ptr(), 0, 0); let function = LLVMAddFunction(module, name.as_ptr(), ty); - add_string_attribute(function, b"entry_point", b""); + add_string_attribute(function, b"entry_point", b"", LLVMAttributeFunctionIndex); add_string_attribute( function, b"required_num_qubits", required_num_qubits.to_string().as_bytes(), + LLVMAttributeFunctionIndex, ); add_string_attribute( function, b"required_num_results", required_num_results.to_string().as_bytes(), + LLVMAttributeFunctionIndex, ); - add_string_attribute(function, b"qir_profiles", qir_profiles.as_bytes()); + add_string_attribute( + function, + b"qir_profiles", + qir_profiles.as_bytes(), + LLVMAttributeFunctionIndex, + ); add_string_attribute( function, b"output_labeling_schema", output_labeling_schema.as_bytes(), + LLVMAttributeFunctionIndex, ); function @@ -200,7 +214,12 @@ pub unsafe fn extract_string(value: LLVMValueRef) -> Option> { Some(data[offset..].to_vec()) } -pub unsafe fn add_string_attribute(function: LLVMValueRef, key: &[u8], value: &[u8]) { +pub unsafe fn add_string_attribute( + function: LLVMValueRef, + key: &[u8], + value: &[u8], + index: LLVMAttributeIndex, +) { let context = LLVMGetTypeContext(LLVMTypeOf(function)); let attr = LLVMCreateStringAttribute( context, @@ -209,7 +228,7 @@ pub unsafe fn add_string_attribute(function: LLVMValueRef, key: &[u8], value: &[ value.as_ptr().cast(), value.len().try_into().unwrap(), ); - LLVMAddAttributeAtIndex(function, LLVMAttributeFunctionIndex, attr); + LLVMAddAttributeAtIndex(function, index, attr); } unsafe fn get_string_attribute( @@ -225,6 +244,55 @@ unsafe fn get_string_attribute( )) } +pub unsafe fn get_attribute_count(function: LLVMValueRef, index: LLVMAttributeIndex) -> usize { + LLVMGetAttributeCountAtIndex(function, index) + .try_into() + .expect("Attribute count larger than usize.") +} + +pub unsafe fn get_string_attribute_kind(attr: *mut LLVMOpaqueAttributeRef) -> String { + let mut len = 0; + let value = LLVMGetStringAttributeKind(attr, &mut len).cast(); + let value = slice::from_raw_parts(value, len.try_into().unwrap()); + str::from_utf8(value) + .expect("Attribute kind is not valid UTF-8.") + .to_string() +} + +pub unsafe fn get_string_attribute_value(attr: *mut LLVMOpaqueAttributeRef) -> Option { + if LLVMIsStringAttribute(attr) == 0 { + None + } else { + let mut len = 0; + let value = LLVMGetStringAttributeValue(attr, &mut len).cast(); + let value = slice::from_raw_parts(value, len.try_into().unwrap()); + Some( + str::from_utf8(value) + .expect("Attribute kind is not valid UTF-8.") + .to_string(), + ) + } +} + +pub unsafe fn get_attributes( + function: LLVMValueRef, + index: LLVMAttributeIndex, +) -> Vec<*mut LLVMOpaqueAttributeRef> { + let count = get_attribute_count(function, index); + if count == 0 { + return Vec::new(); + } + let attrs: Vec> = Vec::with_capacity(count); + let mut attrs = ManuallyDrop::new(attrs); + for _ in 0..count { + attrs.push(MaybeUninit::uninit()); + } + + LLVMGetAttributesAtIndex(function, index, attrs.as_mut_ptr().cast()); + + Vec::from_raw_parts(attrs.as_mut_ptr().cast(), attrs.len(), attrs.capacity()) +} + unsafe fn pointer_to_int(value: LLVMValueRef) -> Option { let ty = LLVMTypeOf(value); if LLVMGetTypeKind(ty) == LLVMTypeKind::LLVMPointerTypeKind && LLVMIsConstant(value) != 0 { @@ -270,3 +338,222 @@ mod tests { assert_reference_ir("module/many_required_qubits_results", 5, 7, |_| ()); } } + +#[cfg(test)] +mod string_attribute_tests { + use std::ffi::CString; + + use llvm_sys::{ + core::{ + LLVMAddFunction, LLVMBuildRetVoid, LLVMContextCreate, LLVMContextDispose, + LLVMCreateBuilderInContext, LLVMDisposeModule, LLVMModuleCreateWithNameInContext, + LLVMVoidTypeInContext, + }, + LLVMAttributeFunctionIndex, LLVMAttributeIndex, LLVMAttributeReturnIndex, LLVMContext, + LLVMModule, LLVMValue, + }; + + use crate::values::get_attributes; + + use super::{add_string_attribute, get_attribute_count}; + + fn setup_expect( + setup: impl Fn(*mut LLVMContext, *mut LLVMModule, *mut LLVMValue), + expect: impl Fn(*mut LLVMValue), + ) { + unsafe { + let context = LLVMContextCreate(); + let module_name = CString::new("test_module").unwrap(); + let module = LLVMModuleCreateWithNameInContext(module_name.as_ptr(), context); + let function_name = CString::new("test_func").unwrap(); + let function = LLVMAddFunction( + module, + function_name.as_ptr(), + LLVMVoidTypeInContext(context), + ); + let builder = LLVMCreateBuilderInContext(context); + LLVMBuildRetVoid(builder); + setup(context, module, function); + expect(function); + LLVMDisposeModule(module); + LLVMContextDispose(context); + } + } + #[test] + fn get_attribute_count_works_when_function_attrs_exist() { + unsafe { + setup_expect( + |_, _, function| { + add_string_attribute(function, b"entry_point", b"", LLVMAttributeFunctionIndex); + add_string_attribute( + function, + b"required_num_qubits", + b"1", + LLVMAttributeFunctionIndex, + ); + add_string_attribute( + function, + b"required_num_results", + b"2", + LLVMAttributeFunctionIndex, + ); + add_string_attribute( + function, + b"qir_profiles", + b"test", + LLVMAttributeFunctionIndex, + ); + }, + |f| { + let count = get_attribute_count(f, LLVMAttributeFunctionIndex); + assert!(count == 4); + }, + ); + } + } + #[test] + fn attributes_with_kind_only_have_empty_string_values() { + unsafe { + setup_expect( + |_, _, function| { + add_string_attribute(function, b"entry_point", b"", LLVMAttributeFunctionIndex); + }, + |f| { + let count = get_attribute_count(f, LLVMAttributeFunctionIndex); + assert!(count == 1); + let attrs = get_attributes(f, LLVMAttributeFunctionIndex); + for attr in attrs { + if let Some(value) = super::get_string_attribute_value(attr) { + assert_eq!(value, ""); + } else { + panic!("Should have a value"); + } + } + }, + ); + } + } + #[test] + fn attributes_with_kind_only_have_key_matching_kind() { + unsafe { + setup_expect( + |_, _, function| { + add_string_attribute(function, b"entry_point", b"", LLVMAttributeFunctionIndex); + }, + |f| { + let count = get_attribute_count(f, LLVMAttributeFunctionIndex); + assert!(count == 1); + let attrs = get_attributes(f, LLVMAttributeFunctionIndex); + for attr in attrs { + assert_eq!(super::get_string_attribute_kind(attr), "entry_point"); + } + }, + ); + } + } + #[test] + fn attributes_with_key_and_value_have_matching_kind_and_value() { + unsafe { + setup_expect( + |_, _, function| { + add_string_attribute( + function, + b"qir_profiles", + b"test", + LLVMAttributeFunctionIndex, + ); + }, + |f| { + let count = get_attribute_count(f, LLVMAttributeFunctionIndex); + assert!(count == 1); + let attrs = get_attributes(f, LLVMAttributeFunctionIndex); + for attr in attrs { + assert_eq!(super::get_string_attribute_kind(attr), "qir_profiles"); + assert!(super::get_string_attribute_value(attr).is_some()); + assert_eq!(super::get_string_attribute_value(attr).unwrap(), "test"); + } + }, + ); + } + } + #[test] + fn get_attribute_count_works_when_function_attrs_dont_exist() { + unsafe { + setup_expect( + |_, _, _| {}, + |f| { + let count = get_attribute_count(f, LLVMAttributeFunctionIndex); + assert!(count == 0); + }, + ); + } + } + #[test] + fn get_attribute_count_works_when_return_attrs_dont_exist() { + unsafe { + setup_expect( + |_, _, _| {}, + |f| { + let count = get_attribute_count(f, LLVMAttributeReturnIndex); + assert!(count == 0); + }, + ); + } + } + #[test] + fn get_attribute_count_works_when_param_attrs_dont_exist() { + unsafe { + setup_expect( + |_, _, _| {}, + |f| { + const INVALID_PARAM_ID: LLVMAttributeIndex = 1; + let count = get_attribute_count(f, INVALID_PARAM_ID); + assert!(count == 0); + }, + ); + } + } + #[test] + fn iteration_works_when_function_attrs_dont_exist() { + unsafe { + setup_expect( + |_, _, _| {}, + |f| { + let attrs = get_attributes(f, LLVMAttributeFunctionIndex); + for _ in attrs { + panic!("Should not have any attributes") + } + }, + ); + } + } + #[test] + fn iteration_works_when_return_attrs_dont_exist() { + unsafe { + setup_expect( + |_, _, _| {}, + |f| { + let attrs = get_attributes(f, LLVMAttributeReturnIndex); + for _ in attrs { + panic!("Should not have any attributes") + } + }, + ); + } + } + #[test] + fn iteration_works_when_param_attrs_dont_exist() { + unsafe { + setup_expect( + |_, _, _| {}, + |f| { + const INVALID_PARAM_ID: LLVMAttributeIndex = 1; + let attrs = get_attributes(f, INVALID_PARAM_ID); + for _ in attrs { + panic!("Should not have any attributes") + } + }, + ); + } + } +}