Skip to content

Commit

Permalink
Add API for inserting phi nodes (#279)
Browse files Browse the repository at this point in the history
  • Loading branch information
ausbin authored Jun 27, 2024
1 parent 2eb874c commit c536c1b
Show file tree
Hide file tree
Showing 7 changed files with 218 additions and 2 deletions.
24 changes: 24 additions & 0 deletions pyqir/pyqir/_native.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,26 @@ class Builder:
"""
...

def condbr(self, if_: Value, then: BasicBlock, else_: BasicBlock) -> Instruction:
"""
Inserts an conditional branch instruction.
:param if_: The condition
:param then: The destination block if condition is 1
:param else_: The destination block if condition is 0
:returns: The branch instruction.
"""
...

def phi(self, value: Type) -> Phi:
"""
Inserts a phi node.
:param type: The type of the phi node
:returns: The phi node.
"""
...

def ret(self, value: Optional[Value]) -> Instruction:
"""
Inserts a return instruction.
Expand Down Expand Up @@ -739,6 +759,10 @@ class Opcode(Enum):
class Phi(Instruction):
"""A phi node instruction."""

def add_incoming(self, value: Value, block: BasicBlock) -> None:
"""Adds an incoming value to the end of the phi list."""
...

@property
def incoming(self) -> List[Tuple[Value, BasicBlock]]:
"""The incoming values and their preceding basic blocks."""
Expand Down
15 changes: 13 additions & 2 deletions pyqir/pyqir/_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,11 @@ def __init__(
self._num_qubits = num_qubits
self._num_results = num_results

entry_point = pyqir.entry_point(
self._entry_point = pyqir.entry_point(
self._module, entry_point_name, num_qubits, num_results
)
self._builder.insert_at_end(BasicBlock(context, "entry", entry_point))
self._entry_block = BasicBlock(context, "entry", self._entry_point)
self._builder.insert_at_end(self._entry_block)

@property
def context(self) -> Context:
Expand All @@ -84,6 +85,16 @@ def builder(self) -> Builder:
"""The instruction builder."""
return self._builder

@property
def entry_point(self) -> Function:
"""The entry point function (automatically generated)."""
return self._entry_point

@property
def entry_block(self) -> BasicBlock:
"""The first basic block of the entry point (automatically generated)."""
return self._entry_block

def add_external_function(self, name: str, ty: FunctionType) -> Function:
"""
Adds a declaration for an externally linked function to the module.
Expand Down
49 changes: 49 additions & 0 deletions pyqir/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
use crate::{
core::Context,
instructions::IntPredicate,
types::Type,
values::{BasicBlock, Literal, Owner, Value},
};
use const_str::raw_cstr;
Expand Down Expand Up @@ -290,6 +291,54 @@ impl Builder {
}
}

/// Inserts an conditional branch instruction.
///
/// :param BasicBlock if_: The condition
/// :param BasicBlock then: The destination block if condition is 1
/// :param BasicBlock else_: The destination block if condition is 0
/// :returns: The branch instruction.
/// :rtype: Instruction
#[pyo3(text_signature = "(if_, then, else_)")]
fn condbr(
&self,
py: Python,
if_: &Value,
then: PyRef<BasicBlock>,
else_: PyRef<BasicBlock>,
) -> PyResult<PyObject> {
let owner = Owner::merge(
py,
[
&self.owner,
if_.owner(),
then.as_ref().owner(),
else_.as_ref().owner(),
],
)?;
unsafe {
let value = LLVMBuildCondBr(
self.builder.as_ptr(),
if_.as_ptr(),
then.as_ptr(),
else_.as_ptr(),
);
Value::from_raw(py, owner, value)
}
}

/// Inserts a phi node.
///
/// :returns: The phi node.
/// :rtype: Instruction
#[pyo3(text_signature = "(type)")]
fn phi(&self, py: Python, r#type: &Type) -> PyResult<PyObject> {
unsafe {
let owner = self.owner.clone_ref(py);
let value = LLVMBuildPhi(self.builder.as_ptr(), r#type.as_ptr(), raw_cstr!(""));
Value::from_raw(py, owner, value)
}
}

/// Inserts a return instruction.
///
/// :param Value value: The value to return. If `None`, returns void.
Expand Down
9 changes: 9 additions & 0 deletions pyqir/src/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,15 @@ pub(crate) struct Phi;

#[pymethods]
impl Phi {
/// Adds an incoming value to the end of the phi list.
#[pyo3(text_signature = "(value, block)")]
fn add_incoming(slf: PyRef<Self>, value: &Value, block: PyRef<BasicBlock>) {
let slf = slf.into_super().into_super();
unsafe {
LLVMAddIncoming(slf.as_ptr(), &mut value.as_ptr(), &mut block.as_ptr(), 1);
}
}

/// The incoming values and their preceding basic blocks.
///
/// :type: typing.List[typing.Tuple[Value, BasicBlock]]
Expand Down
24 changes: 24 additions & 0 deletions pyqir/tests/resources/test_phi_add.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
; ModuleID = 'phi_add'
source_filename = "phi_add"

define void @main() #0 {
entry:
br i1 true, label %body, label %footer

body: ; preds = %entry
br label %footer

footer: ; preds = %body, %entry
%0 = phi i32 [ 2, %entry ], [ 3, %body ]
%1 = add i32 %0, 1
ret void
}

attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="custom" "required_num_qubits"="1" "required_num_results"="1" }

!llvm.module.flags = !{!0, !1, !2, !3}

!0 = !{i32 1, !"qir_major_version", i32 1}
!1 = !{i32 7, !"qir_minor_version", i32 0}
!2 = !{i32 1, !"dynamic_qubit_management", i1 false}
!3 = !{i32 1, !"dynamic_result_management", i1 false}
23 changes: 23 additions & 0 deletions pyqir/tests/resources/test_phi_constants.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
; ModuleID = 'phi_constants'
source_filename = "phi_constants"

define void @main() #0 {
entry:
br i1 false, label %body, label %footer

body: ; preds = %entry
br label %footer

footer: ; preds = %body, %entry
%0 = phi i32 [ 4, %body ], [ 100, %entry ]
ret void
}

attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="custom" "required_num_qubits"="1" "required_num_results"="1" }

!llvm.module.flags = !{!0, !1, !2, !3}

!0 = !{i32 1, !"qir_major_version", i32 1}
!1 = !{i32 7, !"qir_minor_version", i32 0}
!2 = !{i32 1, !"dynamic_qubit_management", i1 false}
!3 = !{i32 1, !"dynamic_result_management", i1 false}
76 changes: 76 additions & 0 deletions pyqir/tests/test_phi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os

import pytest
from pathlib import Path

import pyqir


def test_constants() -> None:
module = pyqir.SimpleModule("phi_constants", 1, 1)
context = module.context
builder = module.builder
entry_point = module.entry_point

entry = module.entry_block
body = pyqir.BasicBlock(context, "body", entry_point)
footer = pyqir.BasicBlock(context, "footer", entry_point)

builder.insert_at_end(entry)
const0 = pyqir.Constant.null(pyqir.IntType(context, 1))
builder.condbr(const0, body, footer)

builder.insert_at_end(body)
builder.br(footer)

builder.insert_at_end(footer)
i32 = pyqir.IntType(context, 32)
phi = builder.phi(i32)
const_taken = pyqir.const(i32, 4)
const_not_taken = pyqir.const(i32, 100)
phi.add_incoming(const_taken, body)
phi.add_incoming(const_not_taken, entry)

ir = module.ir()

file = os.path.join(os.path.dirname(__file__), "resources/test_phi_constants.ll")
expected = Path(file).read_text()
assert ir == expected


def test_add() -> None:
module = pyqir.SimpleModule("phi_add", 1, 1)
context = module.context
builder = module.builder
entry_point = module.entry_point

entry = module.entry_block
body = pyqir.BasicBlock(context, "body", entry_point)
footer = pyqir.BasicBlock(context, "footer", entry_point)

builder.insert_at_end(entry)
i32 = pyqir.IntType(context, 32)
const1 = pyqir.const(i32, 1)
const2 = pyqir.const(i32, 2)
sum_two = builder.add(const1, const1)
cmp = builder.icmp(pyqir.IntPredicate.EQ, sum_two, const2)
builder.condbr(cmp, body, footer)

builder.insert_at_end(body)
sum_three = builder.add(sum_two, const1)
builder.br(footer)

builder.insert_at_end(footer)
phi = builder.phi(i32)
phi.add_incoming(sum_two, entry)
phi.add_incoming(sum_three, body)
sum_four = builder.add(phi, const1)

ir = module.ir()

file = os.path.join(os.path.dirname(__file__), "resources/test_phi_add.ll")
expected = Path(file).read_text()
assert ir == expected

0 comments on commit c536c1b

Please sign in to comment.