Skip to content

Commit

Permalink
Merge branch 'main' into ch-kernel.builder.reset
Browse files Browse the repository at this point in the history
  • Loading branch information
schweitzpgi authored Jul 17, 2024
2 parents c33b8ab + a674f66 commit 892fd9c
Show file tree
Hide file tree
Showing 6 changed files with 245 additions and 24 deletions.
12 changes: 11 additions & 1 deletion python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,16 @@ add_mlir_python_common_capi_library(CUDAQuantumPythonCAPI
# Instantiation of all Python modules
################################################################################

file(GLOB_RECURSE PYTHON_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/*.py")

add_custom_target(
CopyPythonFiles ALL
COMMAND ${CMAKE_COMMAND} -E copy_directory
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_BINARY_DIR}/python
DEPENDS ${PYTHON_SOURCES}
)

add_mlir_python_modules(CUDAQuantumPythonModules
ROOT_PREFIX "${MLIR_BINARY_DIR}/python/cudaq/mlir"
INSTALL_PREFIX "cudaq/mlir"
Expand All @@ -145,7 +155,7 @@ add_mlir_python_modules(CUDAQuantumPythonModules
CUDAQuantumPythonCAPI
)

file (COPY cudaq DESTINATION ${CMAKE_BINARY_DIR}/python)
add_dependencies(CUDAQuantumPythonModules CopyPythonFiles)

## The Python bindings module for Quake dialect depends on CUDAQ libraries
## which it can't locate since they are in "../../lib" and the 'rpath' is set
Expand Down
113 changes: 99 additions & 14 deletions python/cudaq/kernel/ast_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1291,7 +1291,7 @@ def visit_Call(self, node):
self.visit(keyword.value)
namedArgs[keyword.arg] = self.popValue()

if node.func.id == "len":
if node.func.id == 'len':
listVal = self.ifPointerThenLoad(self.popValue())
if cc.StdvecType.isinstance(listVal.type):
self.pushValue(
Expand All @@ -1305,7 +1305,7 @@ def visit_Call(self, node):
self.emitFatalError(
"__len__ not supported on variables of this type.", node)

if node.func.id == "range":
if node.func.id == 'range':
startVal, endVal, stepVal, isDecrementing = self.__processRangeLoopIterationBounds(
node.args)

Expand Down Expand Up @@ -1361,7 +1361,7 @@ def bodyBuilder(iterVar):
self.pushValue(totalSize)
return

if node.func.id == "enumerate":
if node.func.id == 'enumerate':
# We have to have something "iterable" on the stack,
# could be coming from `range()` or an iterable like `qvector`
totalSize = None
Expand Down Expand Up @@ -1403,11 +1403,12 @@ def extractFunctor(idxVal):
"could not infer enumerate tuple type ({})".format(
iterable.type), node)
else:
# FIXME this should be an `emitFatalError`
assert len(
self.valueStack
) == 2, 'Error in AST processing, should have 2 values on the stack for enumerate {}'.format(
ast.unparse(node) if hasattr(ast, 'unparse') else node)
if len(self.valueStack) != 2:
msg = 'Error in AST processing, should have 2 values on the stack for enumerate {}'.format(
ast.unparse(node) if hasattr(ast, 'unparse'
) else node)
self.emitFatalError(msg)

totalSize = self.popValue()
iterable = self.popValue()
arrTy = cc.PointerType.getElementType(iterable.type)
Expand Down Expand Up @@ -1472,16 +1473,17 @@ def bodyBuilder(iterVar):
complex.CreateOp(self.getComplexType(), real, imag).result)
return

if node.func.id in ["h", "x", "y", "z", "s", "t"]:
if node.func.id in ['h', 'x', 'y', 'z', 's', 't']:
# Here we enable application of the op on all the
# provided arguments, e.g. `x(qubit)`, `x(qvector)`, `x(q, r)`, etc.
numValues = len(self.valueStack)
qubitTargets = [self.popValue() for _ in range(numValues)]
qubitTargets.reverse()
self.checkControlAndTargetTypes([], qubitTargets)
self.__applyQuantumOperation(node.func.id, [], qubitTargets)
return

if node.func.id in ["ch", "cx", "cy", "cz", "cs", "ct"]:
if node.func.id in ['ch', 'cx', 'cy', 'cz', 'cs', 'ct']:
# These are single target controlled quantum operations
MAX_ARGS = 2
numValues = len(self.valueStack)
Expand All @@ -1504,18 +1506,26 @@ def bodyBuilder(iterVar):
negated_qubit_controls=negatedControlQubits)
return

if node.func.id in ["rx", "ry", "rz", "r1"]:
if node.func.id in ['rx', 'ry', 'rz', 'r1']:
numValues = len(self.valueStack)
if numValues < 2:
self.emitFatalError(
f'invalid number of arguments ({numValues}) passed to {node.func.id} (requires at least 2 arguments)',
node)
qubitTargets = [self.popValue() for _ in range(numValues - 1)]
qubitTargets.reverse()
param = self.popValue()
if IntegerType.isinstance(param.type):
param = arith.SIToFPOp(self.getFloatType(), param).result
elif not F64Type.isinstance(param.type):
self.emitFatalError(
'rotational parameter must be a float, or int.', node)
self.checkControlAndTargetTypes([], qubitTargets)
self.__applyQuantumOperation(node.func.id, [param],
qubitTargets)
return

if node.func.id in ["crx", "cry", "crz", "cr1"]:
if node.func.id in ['crx', 'cry', 'crz', 'cr1']:
## These are single target, one parameter, controlled quantum operations
MAX_ARGS = 3
numValues = len(self.valueStack)
Expand All @@ -1529,13 +1539,16 @@ def bodyBuilder(iterVar):
param = self.popValue()
if IntegerType.isinstance(param.type):
param = arith.SIToFPOp(self.getFloatType(), param).result
elif not F64Type.isinstance(param.type):
self.emitFatalError(
'rotational parameter must be a float, or int.', node)
# Map `crx` to `RxOp`...
opCtor = getattr(
quake, '{}Op'.format(node.func.id.title()[1:].capitalize()))
opCtor([], [param], [control], [target])
return

if node.func.id in ["sdg", "tdg"]:
if node.func.id in ['sdg', 'tdg']:
target = self.popValue()
self.checkControlAndTargetTypes([], [target])
# Map `sdg` to `SOp`...
Expand Down Expand Up @@ -1612,6 +1625,7 @@ def bodyBuilder(iterVal):

if node.func.id == 'reset':
target = self.popValue()
self.checkControlAndTargetTypes([], [target])
if quake.RefType.isinstance(target.type):
quake.ResetOp([], target)
return
Expand All @@ -1638,22 +1652,39 @@ def bodyBuilder(iterVal):
all_args = [
self.popValue() for _ in range(len(self.valueStack))
]
if len(all_args) < 4:
self.emitFatalError(
f'invalid number of arguments ({len(all_args)}) passed to {node.func.id} (requires at least 4 arguments)',
node)
qubitTargets = all_args[:-3]
qubitTargets.reverse()
self.checkControlAndTargetTypes([], qubitTargets)
params = all_args[-3:]
params.reverse()
for idx, val in enumerate(params):
if IntegerType.isinstance(val.type):
params[idx] = arith.SIToFPOp(self.getFloatType(),
val).result
elif not F64Type.isinstance(val.type):
self.emitFatalError(
'rotational parameter must be a float, or int.',
node)
self.__applyQuantumOperation(node.func.id, params, qubitTargets)
return

if node.func.id in globalRegisteredOperations:
unitary = globalRegisteredOperations[node.func.id]
numTargets = int(np.log2(np.sqrt(unitary.size)))

numValues = len(self.valueStack)
if numValues != numTargets:
self.emitFatalError(
f'invalid number of arguments ({numValues}) passed to {node.func.id} (requires {numTargets} arguments)',
node)

targets = [self.popValue() for _ in range(numTargets)]
targets.reverse()

self.checkControlAndTargetTypes([], targets)

globalName = f'{nvqppPrefix}{node.func.id}_generator_{numTargets}.rodata'
Expand All @@ -1673,6 +1704,25 @@ def bodyBuilder(iterVal):
is_adj=False)
return

# Handle the case where we are capturing an opaque kernel
# function. It has to be in the capture vars and it has to
# be a PyKernelDecorator.
if node.func.id in self.capturedVars and node.func.id not in globalKernelRegistry:
from .kernel_decorator import PyKernelDecorator
var = self.capturedVars[node.func.id]
if isinstance(var, PyKernelDecorator):
# If we found it, then compile its ASTModule to MLIR so
# that it is in the proper registries, then give it
# the proper function alias
PyASTBridge(var.capturedDataStorage,
existingModule=self.module,
locationOffset=var.location).visit(
var.astModule)
# If we have an alias, make sure we point back to the
# kernel registry correctly for the next conditional check
if var.name in globalKernelRegistry:
node.func.id = var.name

if node.func.id in globalKernelRegistry:
# If in `globalKernelRegistry`, it has to be in this Module
otherKernel = SymbolTable(self.module.operation)[nvqppPrefix +
Expand Down Expand Up @@ -2182,6 +2232,10 @@ def maybeProposeOpAttrFix(opName, attrName):
controls = [
self.popValue() for i in range(len(node.args) - 1)
]
if not controls:
self.emitFatalError(
'controlled operation requested without any control argument(s).',
node)
negatedControlQubits = None
if len(self.controlNegations):
negCtrlBools = [None] * len(controls)
Expand Down Expand Up @@ -2235,6 +2289,10 @@ def bodyBuilder(iterVal):
controls = [
self.popValue() for i in range(len(self.valueStack))
]
if not controls:
self.emitFatalError(
'controlled operation requested without any control argument(s).',
node)
opCtor = getattr(quake,
'{}Op'.format(node.func.value.id.title()))
self.checkControlAndTargetTypes(controls, [targetA, targetB])
Expand All @@ -2249,9 +2307,17 @@ def bodyBuilder(iterVal):
]
param = controls[-1]
controls = controls[:-1]
if not controls:
self.emitFatalError(
'controlled operation requested without any control argument(s).',
node)
if IntegerType.isinstance(param.type):
param = arith.SIToFPOp(self.getFloatType(),
param).result
elif not F64Type.isinstance(param.type):
self.emitFatalError(
'rotational parameter must be a float, or int.',
node)
opCtor = getattr(quake,
'{}Op'.format(node.func.value.id.title()))
self.checkControlAndTargetTypes(controls, [target])
Expand All @@ -2264,6 +2330,10 @@ def bodyBuilder(iterVal):
if IntegerType.isinstance(param.type):
param = arith.SIToFPOp(self.getFloatType(),
param).result
elif not F64Type.isinstance(param.type):
self.emitFatalError(
'rotational parameter must be a float, or int.',
node)
opCtor = getattr(quake,
'{}Op'.format(node.func.value.id.title()))
self.checkControlAndTargetTypes([], [target])
Expand Down Expand Up @@ -2302,13 +2372,20 @@ def bodyBuilder(iterVal):

if node.func.attr == 'ctrl':
controls = other_args[:-3]
if not controls:
self.emitFatalError(
'controlled operation requested without any control argument(s).',
node)
params = other_args[-3:]
params.reverse()
for idx, val in enumerate(params):
if IntegerType.isinstance(val.type):
params[idx] = arith.SIToFPOp(
self.getFloatType(), val).result

elif not F64Type.isinstance(val.type):
self.emitFatalError(
'rotational parameter must be a float, or int.',
node)
negatedControlQubits = None
if len(self.controlNegations):
negCtrlBools = [None] * len(controls)
Expand All @@ -2332,6 +2409,10 @@ def bodyBuilder(iterVal):
if IntegerType.isinstance(val.type):
params[idx] = arith.SIToFPOp(
self.getFloatType(), val).result
elif not F64Type.isinstance(val.type):
self.emitFatalError(
'rotational parameter must be a float, or int.',
node)

self.checkControlAndTargetTypes([], [target])
if quake.VeqType.isinstance(target.type):
Expand Down Expand Up @@ -2385,6 +2466,10 @@ def bodyBuilder(iterVal):
controls = [
self.popValue() for _ in range(numValues - numTargets)
]
if not controls:
self.emitFatalError(
'controlled operation requested without any control argument(s).',
node)
negatedControlQubits = None
if len(self.controlNegations):
negCtrlBools = [None] * len(controls)
Expand Down
20 changes: 20 additions & 0 deletions python/tests/builder/test_kernel_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1420,6 +1420,26 @@ def test_builder_rotate_state():
assert '10' in counts


@skipIfPythonLessThan39
def test_issue_9():

kernel, features = cudaq.make_kernel(list)
qubits = kernel.qalloc(8)
kernel.rx(features[0], qubits[100])

with pytest.raises(RuntimeError) as error:
kernel([3.14])


def test_issue_670():

kernel = cudaq.make_kernel()
qubits = kernel.qalloc(1)
kernel.ry(0.1, qubits)

cudaq.sample(kernel)


# leave for gdb debugging
if __name__ == "__main__":
loc = os.path.abspath(__file__)
Expand Down
10 changes: 5 additions & 5 deletions python/tests/display/test_draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ def kernel():
r1(3.14159, q[0])
tdg(q[1])
s(q[2])
swap.ctrl(q[0], q[2])
swap.ctrl(q[1], q[2])
swap.ctrl(q[0], q[1])
swap.ctrl(q[0], q[2])
swap.ctrl(q[1], q[2])
swap(q[0], q[2])
swap(q[1], q[2])
swap(q[0], q[1])
swap(q[0], q[2])
swap(q[1], q[2])
swap.ctrl(q[3], q[0], q[1])
swap.ctrl(q[0], q[3], q[1], q[2])
swap.ctrl(q[1], q[0], q[3])
Expand Down
Loading

0 comments on commit 892fd9c

Please sign in to comment.