Skip to content

Commit

Permalink
Generalize API for attributes (#211)
Browse files Browse the repository at this point in the history
  • Loading branch information
bamarsha authored Dec 5, 2022
1 parent a7c6b29 commit 7a1c13f
Show file tree
Hide file tree
Showing 7 changed files with 193 additions and 44 deletions.
4 changes: 4 additions & 0 deletions pyqir/pyqir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from pyqir._native import (
ArrayType,
Attribute,
AttributeList,
AttributeSet,
BasicBlock,
BasicQisBuilder,
Builder,
Expand Down Expand Up @@ -51,6 +53,8 @@
__all__ = [
"ArrayType",
"Attribute",
"AttributeList",
"AttributeSet",
"BasicBlock",
"BasicQisBuilder",
"Builder",
Expand Down
54 changes: 45 additions & 9 deletions pyqir/pyqir/_native.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,48 @@ class Attribute:
"""An attribute."""

@property
def value(self) -> str:
"""The value of the attribute as a string."""
def string_value(self) -> Optional[str]:
"""The value of this attribute as a string, or `None` if this is not a string attribute."""
...

class AttributeList:
"""The attribute list for a function."""

def param(self, n: int) -> AttributeSet:
"""
The attributes for a parameter.
:param n: The parameter number, starting from zero.
:returns: The parameter attributes.
"""
...
@property
def ret(self) -> AttributeSet:
"""The attributes for the return type."""
...
@property
def func(self) -> AttributeSet:
"""The attributes for the function itself."""
...

class AttributeSet:
"""A set of attributes for a specific part of a function."""

def __contains__(self, item: str) -> bool:
"""
Tests if an attribute is a member of the set.
:param item: The attribute kind.
:returns: True if the group has an attribute with the given kind.
"""
...
def __getitem__(self, key: str) -> Attribute:
"""
Gets an attribute based on its kind.
:param key: The attribute kind.
:returns: The attribute.
"""
...

class BasicBlock(Value):
Expand Down Expand Up @@ -443,13 +483,9 @@ class Function(Constant):
def basic_blocks(self) -> List[BasicBlock]:
"""The basic blocks in this function."""
...
def attribute(self, name: str) -> Optional[Attribute]:
"""
Gets an attribute of this function with the given name if it has one.
:param name: The name of the attribute.
:returns: The attribute.
"""
@property
def attributes(self) -> AttributeList:
"""The attributes for this function."""
...

class FunctionType(Type):
Expand Down
18 changes: 0 additions & 18 deletions pyqir/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,21 +201,3 @@ impl From<Linkage> for inkwell::module::Linkage {
}
}
}

/// An attribute.
#[pyclass(unsendable)]
pub(crate) struct Attribute(pub(crate) inkwell::attributes::Attribute);

#[pymethods]
impl Attribute {
/// The value of the attribute as a string.
///
/// :type: str
#[getter]
fn value(&self) -> &str {
self.0
.get_string_value()
.to_str()
.expect("Value is not valid UTF-8.")
}
}
7 changes: 5 additions & 2 deletions pyqir/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
instructions::{
Call, FCmp, FloatPredicate, ICmp, Instruction, IntPredicate, Opcode, Phi, Switch,
},
module::{Attribute, Linkage, Module},
module::{Linkage, Module},
qis::BasicQisBuilder,
types::{
is_qubit_type, is_result_type, qubit_type, result_type, ArrayType, FunctionType, IntType,
Expand All @@ -17,7 +17,8 @@ use crate::{
values::{
entry_point, extract_byte_string, global_byte_string, is_entry_point, is_interop_friendly,
qubit, qubit_id, r#const, required_num_qubits, required_num_results, result, result_id,
BasicBlock, Constant, FloatConstant, Function, IntConstant, Value,
Attribute, AttributeList, AttributeSet, BasicBlock, Constant, FloatConstant, Function,
IntConstant, Value,
},
};
use pyo3::prelude::*;
Expand All @@ -26,6 +27,8 @@ use pyo3::prelude::*;
fn _native(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<ArrayType>()?;
m.add_class::<Attribute>()?;
m.add_class::<AttributeList>()?;
m.add_class::<AttributeSet>()?;
m.add_class::<BasicBlock>()?;
m.add_class::<BasicQisBuilder>()?;
m.add_class::<Builder>()?;
Expand Down
116 changes: 104 additions & 12 deletions pyqir/src/values.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
use crate::{
context::Context,
instructions::Instruction,
module::{Attribute, Linkage, Module},
module::{Linkage, Module},
types::{FunctionType, Type},
};
use inkwell::{
Expand All @@ -28,7 +28,7 @@ use llvm_sys::{
};
use pyo3::{
conversion::ToPyObject,
exceptions::{PyTypeError, PyValueError},
exceptions::{PyKeyError, PyTypeError, PyValueError},
prelude::*,
types::{PyBytes, PyLong},
PyRef,
Expand Down Expand Up @@ -443,16 +443,10 @@ impl Function {
.collect()
}

/// Gets an attribute of this function with the given name if it has one.
///
/// :param str name: The name of the attribute.
/// :returns: The attribute.
/// :rtype: typing.Optional[Attribute]
#[pyo3(text_signature = "(name)")]
fn attribute(&self, name: &str) -> Option<Attribute> {
Some(Attribute(
self.0.get_string_attribute(AttributeLoc::Function, name)?,
))
/// The attributes for this function.
#[getter]
fn attributes(slf: Py<Function>) -> AttributeList {
AttributeList(slf)
}
}

Expand All @@ -462,6 +456,104 @@ impl Function {
}
}

/// An attribute.
#[pyclass(unsendable)]
pub(crate) struct Attribute(pub(crate) inkwell::attributes::Attribute);

#[pymethods]
impl Attribute {
/// The value of this attribute as a string, or `None` if this is not a string attribute.
///
/// :type: typing.Optional[str]
#[getter]
fn string_value(&self) -> Option<&str> {
if self.0.is_string() {
Some(
self.0
.get_string_value()
.to_str()
.expect("Value is not valid UTF-8."),
)
} else {
None
}
}
}

/// The attribute list for a function.
#[pyclass]
pub(crate) struct AttributeList(Py<Function>);

#[pymethods]
impl AttributeList {
/// The attributes for a parameter.
///
/// :param int n: The parameter number, starting from zero.
/// :returns: The parameter attributes.
/// :rtype: AttributeDict
fn param(&self, py: Python, n: u32) -> AttributeSet {
AttributeSet {
function: self.0.clone_ref(py),
index: AttributeLoc::Param(n),
}
}

/// The attributes for the return type.
///
/// :type: AttributeDict
#[getter]
fn ret(&self, py: Python) -> AttributeSet {
AttributeSet {
function: self.0.clone_ref(py),
index: AttributeLoc::Return,
}
}

/// The attributes for the function itself.
///
/// :type: AttributeDict
#[getter]
fn func(&self, py: Python) -> AttributeSet {
AttributeSet {
function: self.0.clone_ref(py),
index: AttributeLoc::Function,
}
}
}

/// A set of attributes for a specific part of a function.
#[pyclass]
pub(crate) struct AttributeSet {
function: Py<Function>,
index: AttributeLoc,
}

#[pymethods]
impl AttributeSet {
/// Tests if an attribute is a member of the set.
///
/// :param str item: The attribute kind.
/// :returns: True if the group has an attribute with the given kind.
/// :rtype: bool
fn __contains__(&self, py: Python, item: &str) -> bool {
unsafe { self.function.borrow(py).get() }
.get_string_attribute(self.index, item)
.is_some()
}

/// Gets an attribute based on its kind.
///
/// :param str key: The attribute kind.
/// :returns: The attribute.
/// :rtype: Attribute
fn __getitem__(&self, py: Python, key: &str) -> PyResult<Attribute> {
unsafe { self.function.borrow(py).get() }
.get_string_attribute(self.index, key)
.map(Attribute)
.ok_or_else(|| PyKeyError::new_err(key.to_owned()))
}
}

#[derive(Clone, Copy)]
pub(crate) enum AnyValue<'ctx> {
Any(AnyValueEnum<'ctx>),
Expand Down
6 changes: 6 additions & 0 deletions pyqir/tests/attributes.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
; ModuleID = 'attributes'
source_filename = "attributes"

declare "ret_attr"="ret value" i1 @foo(i64 "param0_attr"="param0 value" %0, double %1, i8* "param2_attr"="param2 value" %2) #0

attributes #0 = { "fn_attr"="fn value" }
32 changes: 29 additions & 3 deletions pyqir/tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,8 @@ def test_parser_internals() -> None:
assert interop_funcs[0].name == func_name
assert required_num_qubits(interop_funcs[0]) == 6

attribute = interop_funcs[0].attribute("requiredQubits")
assert attribute is not None
assert attribute.value == "6"
attribute = interop_funcs[0].attributes.func["requiredQubits"]
assert attribute.string_value == "6"

blocks = func.basic_blocks
assert len(blocks) == 9
Expand Down Expand Up @@ -259,3 +258,30 @@ def test_parser_internals() -> None:

assert not entry_block.instructions[10].type.is_void
assert entry_block.instructions[10].name == ""


def test_attribute_values() -> None:
ir = Path("tests/attributes.ll").read_text()
module = Module.from_ir(Context(), ir)
attributes = module.functions[0].attributes
assert attributes.ret["ret_attr"].string_value == "ret value"
assert attributes.param(0)["param0_attr"].string_value == "param0 value"
assert attributes.param(2)["param2_attr"].string_value == "param2 value"
assert attributes.func["fn_attr"].string_value == "fn value"


def test_contains_attribute() -> None:
ir = Path("tests/attributes.ll").read_text()
module = Module.from_ir(Context(), ir)
attributes = module.functions[0].attributes
assert "ret_attr" in attributes.ret
assert attributes.ret["ret_attr"] is not None


def test_not_contains_attribute() -> None:
ir = Path("tests/attributes.ll").read_text()
module = Module.from_ir(Context(), ir)
attributes = module.functions[0].attributes
assert "foo" not in attributes.ret
with pytest.raises(KeyError):
attributes.ret["foo"]

0 comments on commit 7a1c13f

Please sign in to comment.