From 7a1c13fbda10bb5621d01e578aa6d4121aa0048b Mon Sep 17 00:00:00 2001 From: Sarah Marshall <33814365+samarsha@users.noreply.github.com> Date: Mon, 5 Dec 2022 15:09:09 -0800 Subject: [PATCH] Generalize API for attributes (#211) --- pyqir/pyqir/__init__.py | 4 ++ pyqir/pyqir/_native.pyi | 54 ++++++++++++++--- pyqir/src/module.rs | 18 ------ pyqir/src/python.rs | 7 ++- pyqir/src/values.rs | 116 +++++++++++++++++++++++++++++++++---- pyqir/tests/attributes.ll | 6 ++ pyqir/tests/test_parser.py | 32 +++++++++- 7 files changed, 193 insertions(+), 44 deletions(-) create mode 100644 pyqir/tests/attributes.ll diff --git a/pyqir/pyqir/__init__.py b/pyqir/pyqir/__init__.py index 366534f1..3f281ca6 100644 --- a/pyqir/pyqir/__init__.py +++ b/pyqir/pyqir/__init__.py @@ -4,6 +4,8 @@ from pyqir._native import ( ArrayType, Attribute, + AttributeList, + AttributeSet, BasicBlock, BasicQisBuilder, Builder, @@ -51,6 +53,8 @@ __all__ = [ "ArrayType", "Attribute", + "AttributeList", + "AttributeSet", "BasicBlock", "BasicQisBuilder", "Builder", diff --git a/pyqir/pyqir/_native.pyi b/pyqir/pyqir/_native.pyi index 654bdcc4..33be760e 100644 --- a/pyqir/pyqir/_native.pyi +++ b/pyqir/pyqir/_native.pyi @@ -21,8 +21,48 @@ class Attribute: """An attribute.""" @property - def value(self) -> str: - """The value of the attribute as a string.""" + def string_value(self) -> Optional[str]: + """The value of this attribute as a string, or `None` if this is not a string attribute.""" + ... + +class AttributeList: + """The attribute list for a function.""" + + def param(self, n: int) -> AttributeSet: + """ + The attributes for a parameter. + + :param n: The parameter number, starting from zero. + :returns: The parameter attributes. + """ + ... + @property + def ret(self) -> AttributeSet: + """The attributes for the return type.""" + ... + @property + def func(self) -> AttributeSet: + """The attributes for the function itself.""" + ... + +class AttributeSet: + """A set of attributes for a specific part of a function.""" + + def __contains__(self, item: str) -> bool: + """ + Tests if an attribute is a member of the set. + + :param item: The attribute kind. + :returns: True if the group has an attribute with the given kind. + """ + ... + def __getitem__(self, key: str) -> Attribute: + """ + Gets an attribute based on its kind. + + :param key: The attribute kind. + :returns: The attribute. + """ ... class BasicBlock(Value): @@ -443,13 +483,9 @@ class Function(Constant): def basic_blocks(self) -> List[BasicBlock]: """The basic blocks in this function.""" ... - def attribute(self, name: str) -> Optional[Attribute]: - """ - Gets an attribute of this function with the given name if it has one. - - :param name: The name of the attribute. - :returns: The attribute. - """ + @property + def attributes(self) -> AttributeList: + """The attributes for this function.""" ... class FunctionType(Type): diff --git a/pyqir/src/module.rs b/pyqir/src/module.rs index ec9ee82b..2dc870b3 100644 --- a/pyqir/src/module.rs +++ b/pyqir/src/module.rs @@ -201,21 +201,3 @@ impl From for inkwell::module::Linkage { } } } - -/// An attribute. -#[pyclass(unsendable)] -pub(crate) struct Attribute(pub(crate) inkwell::attributes::Attribute); - -#[pymethods] -impl Attribute { - /// The value of the attribute as a string. - /// - /// :type: str - #[getter] - fn value(&self) -> &str { - self.0 - .get_string_value() - .to_str() - .expect("Value is not valid UTF-8.") - } -} diff --git a/pyqir/src/python.rs b/pyqir/src/python.rs index 66243492..bf37e2dc 100644 --- a/pyqir/src/python.rs +++ b/pyqir/src/python.rs @@ -8,7 +8,7 @@ use crate::{ instructions::{ Call, FCmp, FloatPredicate, ICmp, Instruction, IntPredicate, Opcode, Phi, Switch, }, - module::{Attribute, Linkage, Module}, + module::{Linkage, Module}, qis::BasicQisBuilder, types::{ is_qubit_type, is_result_type, qubit_type, result_type, ArrayType, FunctionType, IntType, @@ -17,7 +17,8 @@ use crate::{ values::{ entry_point, extract_byte_string, global_byte_string, is_entry_point, is_interop_friendly, qubit, qubit_id, r#const, required_num_qubits, required_num_results, result, result_id, - BasicBlock, Constant, FloatConstant, Function, IntConstant, Value, + Attribute, AttributeList, AttributeSet, BasicBlock, Constant, FloatConstant, Function, + IntConstant, Value, }, }; use pyo3::prelude::*; @@ -26,6 +27,8 @@ use pyo3::prelude::*; fn _native(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/pyqir/src/values.rs b/pyqir/src/values.rs index 40c7c021..aea63eaf 100644 --- a/pyqir/src/values.rs +++ b/pyqir/src/values.rs @@ -6,7 +6,7 @@ use crate::{ context::Context, instructions::Instruction, - module::{Attribute, Linkage, Module}, + module::{Linkage, Module}, types::{FunctionType, Type}, }; use inkwell::{ @@ -28,7 +28,7 @@ use llvm_sys::{ }; use pyo3::{ conversion::ToPyObject, - exceptions::{PyTypeError, PyValueError}, + exceptions::{PyKeyError, PyTypeError, PyValueError}, prelude::*, types::{PyBytes, PyLong}, PyRef, @@ -443,16 +443,10 @@ impl Function { .collect() } - /// Gets an attribute of this function with the given name if it has one. - /// - /// :param str name: The name of the attribute. - /// :returns: The attribute. - /// :rtype: typing.Optional[Attribute] - #[pyo3(text_signature = "(name)")] - fn attribute(&self, name: &str) -> Option { - Some(Attribute( - self.0.get_string_attribute(AttributeLoc::Function, name)?, - )) + /// The attributes for this function. + #[getter] + fn attributes(slf: Py) -> AttributeList { + AttributeList(slf) } } @@ -462,6 +456,104 @@ impl Function { } } +/// An attribute. +#[pyclass(unsendable)] +pub(crate) struct Attribute(pub(crate) inkwell::attributes::Attribute); + +#[pymethods] +impl Attribute { + /// The value of this attribute as a string, or `None` if this is not a string attribute. + /// + /// :type: typing.Optional[str] + #[getter] + fn string_value(&self) -> Option<&str> { + if self.0.is_string() { + Some( + self.0 + .get_string_value() + .to_str() + .expect("Value is not valid UTF-8."), + ) + } else { + None + } + } +} + +/// The attribute list for a function. +#[pyclass] +pub(crate) struct AttributeList(Py); + +#[pymethods] +impl AttributeList { + /// The attributes for a parameter. + /// + /// :param int n: The parameter number, starting from zero. + /// :returns: The parameter attributes. + /// :rtype: AttributeDict + fn param(&self, py: Python, n: u32) -> AttributeSet { + AttributeSet { + function: self.0.clone_ref(py), + index: AttributeLoc::Param(n), + } + } + + /// The attributes for the return type. + /// + /// :type: AttributeDict + #[getter] + fn ret(&self, py: Python) -> AttributeSet { + AttributeSet { + function: self.0.clone_ref(py), + index: AttributeLoc::Return, + } + } + + /// The attributes for the function itself. + /// + /// :type: AttributeDict + #[getter] + fn func(&self, py: Python) -> AttributeSet { + AttributeSet { + function: self.0.clone_ref(py), + index: AttributeLoc::Function, + } + } +} + +/// A set of attributes for a specific part of a function. +#[pyclass] +pub(crate) struct AttributeSet { + function: Py, + index: AttributeLoc, +} + +#[pymethods] +impl AttributeSet { + /// Tests if an attribute is a member of the set. + /// + /// :param str item: The attribute kind. + /// :returns: True if the group has an attribute with the given kind. + /// :rtype: bool + fn __contains__(&self, py: Python, item: &str) -> bool { + unsafe { self.function.borrow(py).get() } + .get_string_attribute(self.index, item) + .is_some() + } + + /// Gets an attribute based on its kind. + /// + /// :param str key: The attribute kind. + /// :returns: The attribute. + /// :rtype: Attribute + fn __getitem__(&self, py: Python, key: &str) -> PyResult { + unsafe { self.function.borrow(py).get() } + .get_string_attribute(self.index, key) + .map(Attribute) + .ok_or_else(|| PyKeyError::new_err(key.to_owned())) + } +} + #[derive(Clone, Copy)] pub(crate) enum AnyValue<'ctx> { Any(AnyValueEnum<'ctx>), diff --git a/pyqir/tests/attributes.ll b/pyqir/tests/attributes.ll new file mode 100644 index 00000000..1cd69206 --- /dev/null +++ b/pyqir/tests/attributes.ll @@ -0,0 +1,6 @@ +; ModuleID = 'attributes' +source_filename = "attributes" + +declare "ret_attr"="ret value" i1 @foo(i64 "param0_attr"="param0 value" %0, double %1, i8* "param2_attr"="param2 value" %2) #0 + +attributes #0 = { "fn_attr"="fn value" } diff --git a/pyqir/tests/test_parser.py b/pyqir/tests/test_parser.py index b9abf33a..d803f2a4 100644 --- a/pyqir/tests/test_parser.py +++ b/pyqir/tests/test_parser.py @@ -186,9 +186,8 @@ def test_parser_internals() -> None: assert interop_funcs[0].name == func_name assert required_num_qubits(interop_funcs[0]) == 6 - attribute = interop_funcs[0].attribute("requiredQubits") - assert attribute is not None - assert attribute.value == "6" + attribute = interop_funcs[0].attributes.func["requiredQubits"] + assert attribute.string_value == "6" blocks = func.basic_blocks assert len(blocks) == 9 @@ -259,3 +258,30 @@ def test_parser_internals() -> None: assert not entry_block.instructions[10].type.is_void assert entry_block.instructions[10].name == "" + + +def test_attribute_values() -> None: + ir = Path("tests/attributes.ll").read_text() + module = Module.from_ir(Context(), ir) + attributes = module.functions[0].attributes + assert attributes.ret["ret_attr"].string_value == "ret value" + assert attributes.param(0)["param0_attr"].string_value == "param0 value" + assert attributes.param(2)["param2_attr"].string_value == "param2 value" + assert attributes.func["fn_attr"].string_value == "fn value" + + +def test_contains_attribute() -> None: + ir = Path("tests/attributes.ll").read_text() + module = Module.from_ir(Context(), ir) + attributes = module.functions[0].attributes + assert "ret_attr" in attributes.ret + assert attributes.ret["ret_attr"] is not None + + +def test_not_contains_attribute() -> None: + ir = Path("tests/attributes.ll").read_text() + module = Module.from_ir(Context(), ir) + attributes = module.functions[0].attributes + assert "foo" not in attributes.ret + with pytest.raises(KeyError): + attributes.ret["foo"]