From e21a32b8030cbb71eadb592bd81f75dfb5474721 Mon Sep 17 00:00:00 2001 From: Eric Schweitz Date: Tue, 8 Oct 2024 11:02:37 -0700 Subject: [PATCH] [core] Add quantum reference product type (#2254) * Start on pure quantum struct usage in kernels Signed-off-by: Alex McCaskey * Update the python bindings with new qstruct restrictions Signed-off-by: Alex McCaskey * Enable default parenthesis constructor Signed-off-by: Alex McCaskey * disallow recursive quantum struct Signed-off-by: Alex McCaskey * Implement error handling for various cases in python Signed-off-by: Alex McCaskey * spell fixes Signed-off-by: Alex McCaskey * forgot to filter out __qpu__ methods on structs, those are allowed Signed-off-by: Alex McCaskey * Add new quantum reference type, !quake.struq, and a couple of new operations: quake.make_struq and quake.get_member. These add the utility of having a product quantum reference type (to logically group together sets of qubits) but keep the classical and quantum dialects distinct. Update the tests, python ast bridge, C++ bridge, add codegen patterns, etc. * Whackamole games with the CI. Add roundtrip test on new type and ops. Update the python tests. Also change test to eliminate deprecation warnings. Add invlid IR checks for new operations. Add sanity checks. We do not want to allow a quantum struct that holds anything but non-owning references to qubits or qubit collections. Remove unused folder pattern. Workaround for overly assertive compiler warning. Reenable the hash-and-cache of extract_ref ops in the C++ bridge. This is a dubious optimization that we may actually want to take out at some point, but that should be part of a distinct/different PR. Update test to work around that pytest output can be shuffled. Add case to python for quake.struq type. Another python fix. Add explicit checks to utils.py. Stab in the dark. --------- Signed-off-by: Alex McCaskey Co-authored-by: Alex McCaskey --- include/cudaq/Frontend/nvqpp/ASTBridge.h | 1 + .../cudaq/Optimizer/Dialect/Quake/QuakeOps.td | 48 ++++++ .../Optimizer/Dialect/Quake/QuakeTypes.h | 10 +- .../Optimizer/Dialect/Quake/QuakeTypes.td | 44 ++++- lib/Frontend/nvqpp/ConvertDecl.cpp | 22 ++- lib/Frontend/nvqpp/ConvertExpr.cpp | 81 +++++++-- lib/Frontend/nvqpp/ConvertType.cpp | 97 +++++++++-- lib/Optimizer/CodeGen/ConvertToCC.cpp | 6 + lib/Optimizer/CodeGen/ConvertToQIR.cpp | 7 + lib/Optimizer/CodeGen/QuakeToLLVM.cpp | 67 ++++++-- lib/Optimizer/Dialect/Quake/QuakeOps.cpp | 66 ++++++++ lib/Optimizer/Dialect/Quake/QuakeTypes.cpp | 40 ++++- python/cudaq/kernel/ast_bridge.py | 92 ++++++---- python/cudaq/kernel/utils.py | 59 ++++++- python/runtime/mlir/py_register_dialects.cpp | 44 +++++ python/tests/kernel/test_kernel_features.py | 120 +++++++++++-- python/tests/mlir/quantum_struct.py | 39 +++++ python/tests/mlir/quantum_type.py | 59 +++---- python/tests/mlir/test_output_qir.py | 4 +- targettests/execution/quantum_struct.cpp | 70 ++++++++ test/AST-Quake/pure_quantum_struct.cpp | 157 ++++++++++++++++++ ...ment-2.cpp => kernel_invalid_argument.cpp} | 0 .../kernel_with_member_functions.cpp | 23 +++ .../AST-error/quantum_struct_declarations.cpp | 44 +++++ test/AST-error/quantum_struct_signature.cpp | 23 +++ .../quantum_struct_with_struct_member.cpp | 29 ++++ .../struct_quantum_and_classical.cpp | 34 ++++ test/Quake/invalid.qke | 56 +++++++ test/Quake/roundtrip-ops.qke | 38 +++++ 29 files changed, 1242 insertions(+), 138 deletions(-) create mode 100644 python/tests/mlir/quantum_struct.py create mode 100644 targettests/execution/quantum_struct.cpp create mode 100644 test/AST-Quake/pure_quantum_struct.cpp rename test/AST-error/{kernel_invalid_argument-2.cpp => kernel_invalid_argument.cpp} (100%) create mode 100644 test/AST-error/kernel_with_member_functions.cpp create mode 100644 test/AST-error/quantum_struct_declarations.cpp create mode 100644 test/AST-error/quantum_struct_signature.cpp create mode 100644 test/AST-error/quantum_struct_with_struct_member.cpp create mode 100644 test/AST-error/struct_quantum_and_classical.cpp create mode 100644 test/Quake/invalid.qke diff --git a/include/cudaq/Frontend/nvqpp/ASTBridge.h b/include/cudaq/Frontend/nvqpp/ASTBridge.h index e6e6d28aee..baf4518fce 100644 --- a/include/cudaq/Frontend/nvqpp/ASTBridge.h +++ b/include/cudaq/Frontend/nvqpp/ASTBridge.h @@ -286,6 +286,7 @@ class QuakeBridgeVisitor DataRecursionQueue *q = nullptr); bool VisitCXXConstructExpr(clang::CXXConstructExpr *x); bool VisitCXXOperatorCallExpr(clang::CXXOperatorCallExpr *x); + bool VisitCXXParenListInitExpr(clang::CXXParenListInitExpr *x); bool WalkUpFromCXXOperatorCallExpr(clang::CXXOperatorCallExpr *x); bool TraverseDeclRefExpr(clang::DeclRefExpr *x, DataRecursionQueue *q = nullptr); diff --git a/include/cudaq/Optimizer/Dialect/Quake/QuakeOps.td b/include/cudaq/Optimizer/Dialect/Quake/QuakeOps.td index 3d14bad6cd..ea1681e340 100644 --- a/include/cudaq/Optimizer/Dialect/Quake/QuakeOps.td +++ b/include/cudaq/Optimizer/Dialect/Quake/QuakeOps.td @@ -641,6 +641,54 @@ def quake_ReturnWireOp : QuakeOp<"return_wire"> { let assemblyFormat = "$target `:` type(operands) attr-dict"; } +//===----------------------------------------------------------------------===// +// Struq handling +//===----------------------------------------------------------------------===// + +def quake_MakeStruqOp : QuakeOp<"make_struq", [Pure]> { + let summary = "create a quantum struct from a set of quantum references"; + let description = [{ + Given a list of values of quantum reference type, creates a new quantum + product reference type. This is a logical grouping and does not imply any + new quantum references are created. + + This operation can be useful for grouping a number of values of type `veq` + into a logical product type, which may be passed to a pure device kernel + as a single unit, for example. These product types may always be erased into + a vector of the quantum references used to compose them via a make_struq op. + }]; + + let arguments = (ins Variadic:$veqs); + let results = (outs StruqType); + let hasVerifier = 1; + + let assemblyFormat = [{ + $veqs `:` functional-type(operands, results) attr-dict + }]; +} + +def quake_GetMemberOp : QuakeOp<"get_member", [Pure]> { + let summary = "extract quantum references from a quantum struct"; + let description = [{ + The get_member operation can be used to extract a set of quantum references + from a quantum struct (product) type. The fields in the quantum struct are + indexed from 0 to $n-1$ where $n$ is the number of fields. An index outside + of this range will produce a verification error. + }]; + + let arguments = (ins + StruqType:$struq, + I32Attr:$index + ); + let results = (outs NonStruqRefType); + let hasCanonicalizer = 1; + let hasVerifier = 1; + + let assemblyFormat = [{ + $struq `[` $index `]` `:` functional-type(operands, results) attr-dict + }]; +} + //===----------------------------------------------------------------------===// // ToControl, FromControl pair //===----------------------------------------------------------------------===// diff --git a/include/cudaq/Optimizer/Dialect/Quake/QuakeTypes.h b/include/cudaq/Optimizer/Dialect/Quake/QuakeTypes.h index 7b62009b15..c14e22e838 100644 --- a/include/cudaq/Optimizer/Dialect/Quake/QuakeTypes.h +++ b/include/cudaq/Optimizer/Dialect/Quake/QuakeTypes.h @@ -25,7 +25,7 @@ namespace quake { inline bool isQuantumType(mlir::Type ty) { // NB: this intentionally excludes MeasureType. return mlir::isa(ty); + quake::ControlType, quake::StruqType>(ty); } /// \returns true if \p `ty` is a Quake type. @@ -34,10 +34,16 @@ inline bool isQuakeType(mlir::Type ty) { return isQuantumType(ty) || mlir::isa(ty); } -inline bool isQuantumReferenceType(mlir::Type ty) { +/// \returns true if \p ty is a quantum reference type, excluding `struq`. +inline bool isNonStruqReferenceType(mlir::Type ty) { return mlir::isa(ty); } +/// \returns true if \p ty is a quantum reference type. +inline bool isQuantumReferenceType(mlir::Type ty) { + return isNonStruqReferenceType(ty) || mlir::isa(ty); +} + /// A quake wire type is a linear type. inline bool isLinearType(mlir::Type ty) { return mlir::isa(ty); diff --git a/include/cudaq/Optimizer/Dialect/Quake/QuakeTypes.td b/include/cudaq/Optimizer/Dialect/Quake/QuakeTypes.td index a57c3c27bc..2e6fc9dfe3 100644 --- a/include/cudaq/Optimizer/Dialect/Quake/QuakeTypes.td +++ b/include/cudaq/Optimizer/Dialect/Quake/QuakeTypes.td @@ -161,6 +161,41 @@ def VeqType : QuakeType<"Veq", "veq"> { }]; } +//===----------------------------------------------------------------------===// +// StruqType: quantum reference type; product of veq and ref types. +//===----------------------------------------------------------------------===// + +def StruqType : QuakeType<"Struq", "struq"> { + let summary = "a product type of quantum references"; + let description = [{ + This type allows one to group veqs of quantum references together in a + single product type. + + To support Python, a struq type can be assigned a name. This allows the + python bridge to perform dictionary based lookups on member field names. + }]; + + let parameters = (ins + "mlir::StringAttr":$name, + ArrayRefParameter<"mlir::Type">:$members + ); + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = [{ + std::size_t getNumMembers() const { return getMembers().size(); } + }]; + + let builders = [ + TypeBuilder<(ins CArg<"llvm::ArrayRef">:$members), [{ + return $_get($_ctxt, mlir::StringAttr{}, members); + }]>, + TypeBuilder<(ins CArg<"llvm::StringRef">:$name, + CArg<"llvm::ArrayRef">:$members), [{ + return $_get($_ctxt, mlir::StringAttr::get($_ctxt, name), members); + }]> + ]; +} + //===----------------------------------------------------------------------===// // MeasureType: classical data type //===----------------------------------------------------------------------===// @@ -183,14 +218,19 @@ def MeasureType : QuakeType<"Measure", "measure"> { } def AnyQTypeLike : TypeConstraint, "quake quantum types">; + ControlType.predicate, RefType.predicate, StruqType.predicate]>, + "quake quantum types">; def AnyQType : Type; def AnyQTargetTypeLike : TypeConstraint, "quake quantum target types">; def AnyQTargetType : Type; -def AnyRefTypeLike : TypeConstraint, "quake quantum reference types">; def AnyRefType : Type; +def NonStruqRefTypeLike : TypeConstraint, "non-struct quake quantum reference types">; +def NonStruqRefType : Type; def AnyQValueTypeLike : TypeConstraint, "quake quantum value types">; def AnyQValueType : Type; diff --git a/lib/Frontend/nvqpp/ConvertDecl.cpp b/lib/Frontend/nvqpp/ConvertDecl.cpp index 7d277197a8..0c74b0b0da 100644 --- a/lib/Frontend/nvqpp/ConvertDecl.cpp +++ b/lib/Frontend/nvqpp/ConvertDecl.cpp @@ -93,8 +93,8 @@ void QuakeBridgeVisitor::addArgumentSymbols( auto parmTy = entryBlock->getArgument(index).getType(); if (isa(parmTy)) { + quake::ControlType, quake::RefType, quake::StruqType, + quake::VeqType, quake::WireType>(parmTy)) { symbolTable.insert(name, entryBlock->getArgument(index)); } else { auto stackSlot = builder.create(loc, parmTy); @@ -658,9 +658,8 @@ bool QuakeBridgeVisitor::VisitVarDecl(clang::VarDecl *x) { if (auto qType = dyn_cast(type)) { // Variable is of !quake.ref type. if (x->hasInit() && !valueStack.empty()) { - auto val = popValue(); - symbolTable.insert(name, val); - return pushValue(val); + symbolTable.insert(name, peekValue()); + return true; } auto zero = builder.create( loc, 0, builder.getIntegerType(64)); @@ -672,6 +671,13 @@ bool QuakeBridgeVisitor::VisitVarDecl(clang::VarDecl *x) { return pushValue(addressTheQubit); } + if (isa(type)) { + // A pure quantum struct is just passed along by value. It cannot be stored + // to a variable. + symbolTable.insert(name, peekValue()); + return true; + } + // Here we maybe have something like auto var = mz(qreg) if (auto vecType = dyn_cast(type)) { // Variable is of !cc.stdvec type. @@ -805,6 +811,12 @@ bool QuakeBridgeVisitor::VisitVarDecl(clang::VarDecl *x) { return pushValue(cast); } + // Don't allocate memory for a quantum or value-semantic struct. + if (auto insertValOp = initValue.getDefiningOp()) { + symbolTable.insert(x->getName(), initValue); + return pushValue(initValue); + } + // Initialization expression resulted in a value. Create a variable and save // that value to the variable's memory address. Value alloca = builder.create(loc, type); diff --git a/lib/Frontend/nvqpp/ConvertExpr.cpp b/lib/Frontend/nvqpp/ConvertExpr.cpp index d05b74af1e..5976e1355d 100644 --- a/lib/Frontend/nvqpp/ConvertExpr.cpp +++ b/lib/Frontend/nvqpp/ConvertExpr.cpp @@ -1109,14 +1109,23 @@ bool QuakeBridgeVisitor::VisitMemberExpr(clang::MemberExpr *x) { if (auto *field = dyn_cast(x->getMemberDecl())) { auto loc = toLocation(x->getSourceRange()); auto object = popValue(); // DeclRefExpr + auto ty = popType(); + std::int32_t offset = field->getFieldIndex(); + if (isa(object.getType())) { + return pushValue( + builder.create(loc, ty, object, offset)); + } + if (!isa(object.getType())) { + reportClangError(x, mangler, + "internal error: struct must be an object in memory"); + return false; + } auto eleTy = cast(object.getType()).getElementType(); SmallVector offsets; if (auto arrTy = dyn_cast(eleTy)) if (arrTy.isUnknownSize()) offsets.push_back(0); - std::int32_t offset = field->getFieldIndex(); offsets.push_back(offset); - auto ty = popType(); return pushValue(builder.create( loc, cc::PointerType::get(ty), object, offsets)); } @@ -2199,11 +2208,26 @@ bool QuakeBridgeVisitor::VisitCXXOperatorCallExpr( if (isCudaQType(typeName)) { auto idx_var = popValue(); auto qreg_var = popValue(); - + auto *arg0 = x->getArg(0); + if (isa(arg0)) { + // This is a subscript operator on a data member and the type is a + // quantum type (likely a `qview`). This can only happen in a quantum + // `struct`, which the spec says must be one-level deep at most and must + // only contain references to qubits explicitly allocated in other + // variables. `qreg_var` will be a `quake.get_member`. Do not add this + // extract `Op` to the symbol table, but always generate a new + // `quake.extract_ref` `Op` to get the exact qubit (reference) value. + auto address_qubit = + builder.create(loc, qreg_var, idx_var); + return replaceTOSValue(address_qubit); + } // Get name of the qreg, e.g. qr, and use it to construct a name for the // element, which is intended to be qr%n when n is the index of the // accessed qubit. - StringRef qregName = getNamedDecl(x->getArg(0))->getName(); + if (!isa(arg0)) + reportClangError(x, mangler, + "internal error: expected a variable name"); + StringRef qregName = getNamedDecl(arg0)->getName(); auto name = getQubitSymbolTableName(qregName, idx_var); char *varName = strdup(name.c_str()); @@ -2211,12 +2235,15 @@ bool QuakeBridgeVisitor::VisitCXXOperatorCallExpr( if (symbolTable.count(name)) return replaceTOSValue(symbolTable.lookup(name)); - // Otherwise create an operation to access the qubit, store that value in - // the symbol table, and return the AddressQubit operation's resulting - // value. + // Otherwise create an operation to access the qubit, store that value + // in the symbol table, and return the AddressQubit operation's + // resulting value. auto address_qubit = builder.create(loc, qreg_var, idx_var); + // NB: varName is built from the variable name *and* the index value. This + // front-end optimization is likely unnecessary as the compiler can always + // canonicalize and merge identical quake.extract_ref operations. symbolTable.insert(StringRef(varName), address_qubit); return replaceTOSValue(address_qubit); } @@ -2395,7 +2422,10 @@ bool QuakeBridgeVisitor::VisitInitListExpr(clang::InitListExpr *x) { bool allRef = std::all_of(last.begin(), last.end(), [](auto v) { return isa(v.getType()); }); - if (allRef) { + if (allRef && isa(initListTy)) + return pushValue(builder.create(loc, initListTy, last)); + + if (allRef && !isa(initListTy)) { // Initializer list contains all quantum reference types. In this case we // want to create quake code to concatenate the references into a veq. if (size > 1) { @@ -2466,6 +2496,11 @@ bool QuakeBridgeVisitor::VisitInitListExpr(clang::InitListExpr *x) { auto globalInit = builder.create(loc, ptrTy, name); return pushValue(globalInit); } + + // If quantum, use value semantics with cc insert / extract value. + if (isa(eleTy)) + return pushValue(builder.create(loc, eleTy, last)); + Value alloca = (numEles > 1) ? builder.create(loc, eleTy, arrSize) : builder.create(loc, eleTy); @@ -2556,6 +2591,19 @@ static Type getEleTyFromVectorCtor(Type ctorTy) { return ctorTy; } +bool QuakeBridgeVisitor::VisitCXXParenListInitExpr( + clang::CXXParenListInitExpr *x) { + auto ty = peekType(); + assert(ty && "type must be present"); + LLVM_DEBUG(llvm::dbgs() << "paren list type: " << ty << '\n'); + auto structTy = dyn_cast(ty); + if (!structTy) + return true; + auto loc = toLocation(x); + auto last = lastValues(structTy.getMembers().size()); + return pushValue(builder.create(loc, structTy, last)); +} + bool QuakeBridgeVisitor::VisitCXXConstructExpr(clang::CXXConstructExpr *x) { auto loc = toLocation(x); auto *ctor = x->getConstructor(); @@ -2855,12 +2903,17 @@ bool QuakeBridgeVisitor::VisitCXXConstructExpr(clang::CXXConstructExpr *x) { return true; } - if (ctor->isCopyOrMoveConstructor() && parent->isPOD()) { - // Copy or move constructor on a POD struct. The value stack should contain - // the object to load the value from. - auto fromStruct = popValue(); - assert(isa(ctorTy) && "POD must be a struct type"); - return pushValue(builder.create(loc, fromStruct)); + if (ctor->isCopyOrMoveConstructor()) { + // Just walk through copy constructors for quantum struct types. + if (isa(ctorTy)) + return true; + if (parent->isPOD()) { + // Copy or move constructor on a POD struct. The value stack should + // contain the object to load the value from. + auto fromStruct = popValue(); + assert(isa(ctorTy) && "POD must be a struct type"); + return pushValue(builder.create(loc, fromStruct)); + } } if (ctor->isCopyConstructor() && ctor->isTrivial() && diff --git a/lib/Frontend/nvqpp/ConvertType.cpp b/lib/Frontend/nvqpp/ConvertType.cpp index 91b79b040b..6accee6bac 100644 --- a/lib/Frontend/nvqpp/ConvertType.cpp +++ b/lib/Frontend/nvqpp/ConvertType.cpp @@ -22,12 +22,6 @@ static bool isArithmeticType(Type t) { return isa(t); } -/// Is \p t a quantum reference type. In the bridge, quantum types are always -/// reference types. -static bool isQuantumType(Type t) { - return isa(t); -} - /// Allow `array of [array of]* T`, where `T` is arithmetic. static bool isStaticArithmeticSequenceType(Type t) { if (auto vec = dyn_cast(t)) { @@ -144,7 +138,8 @@ static bool isKernelResultType(Type t) { /// (function), or a string. static bool isKernelArgumentType(Type t) { return isArithmeticType(t) || isComposedArithmeticType(t) || - isQuantumType(t) || isKernelCallable(t) || isFunctionCallable(t) || + quake::isQuantumReferenceType(t) || isKernelCallable(t) || + isFunctionCallable(t) || // TODO: move from pointers to a builtin string type. cudaq::isCharPointerType(t); } @@ -243,13 +238,93 @@ bool QuakeBridgeVisitor::VisitRecordDecl(clang::RecordDecl *x) { auto *ctx = builder.getContext(); if (!x->getDefinition()) return pushType(cc::StructType::get(ctx, name, /*isOpaque=*/true)); + SmallVector fieldTys = lastTypes(std::distance(x->field_begin(), x->field_end())); auto [width, alignInBytes] = getWidthAndAlignment(x); - if (name.empty()) - return pushType(cc::StructType::get(ctx, fieldTys, width, alignInBytes)); - return pushType( - cc::StructType::get(ctx, name, fieldTys, width, alignInBytes)); + bool isStruq = !fieldTys.empty(); + for (auto ty : fieldTys) + if (!quake::isQuantumReferenceType(ty)) + isStruq = false; + + auto ty = [&]() -> Type { + if (isStruq) + return quake::StruqType::get(ctx, fieldTys); + if (name.empty()) + return cc::StructType::get(ctx, fieldTys, width, alignInBytes); + return cc::StructType::get(ctx, name, fieldTys, width, alignInBytes); + }(); + + // Do some error analysis on the product type. Check the following: + + // - If this is a struq: + if (isa(ty)) { + // -- does it contain invalid C++ types? + for (auto *field : x->fields()) { + auto *ty = field->getType().getTypePtr(); + bool isRef = false; + if (ty->isLValueReferenceType()) { + auto *lref = cast(ty); + isRef = true; + ty = lref->getPointeeType().getTypePtr(); + } + if (auto *tyDecl = ty->getAsRecordDecl()) { + if (auto *ident = tyDecl->getIdentifier()) { + auto name = ident->getName(); + if (isInNamespace(tyDecl, "cudaq")) { + if (isRef) { + // can be owning container; so can be qubit, qarray, or qvector + if ((name.equals("qudit") || name.equals("qubit") || + name.equals("qvector") || name.equals("qarray"))) + continue; + } + // must be qview or qview& + if (name.equals("qview")) + continue; + } + } + } + reportClangError(x, mangler, "quantum struct has invalid member type."); + } + // -- does it contain contain a struq member? Not allowed. + for (auto fieldTy : fieldTys) + if (isa(fieldTy)) + reportClangError(x, mangler, + "recursive quantum struct types are not allowed."); + } + + // - Is this a struct does it have quantum types? Not allowed. + if (!isa(ty)) + for (auto fieldTy : fieldTys) + if (quake::isQuakeType(fieldTy)) + reportClangError( + x, mangler, + "hybrid quantum-classical struct types are not allowed."); + + // - Does this product type have (user-defined) member functions? Not allowed. + if (auto *cxxRd = dyn_cast(x)) { + auto numMethods = [&cxxRd]() { + std::size_t count = 0; + for (auto methodIter = cxxRd->method_begin(); + methodIter != cxxRd->method_end(); ++methodIter) { + // Don't check if this is a __qpu__ struct method + if (auto attr = (*methodIter)->getAttr(); + attr && attr->getAnnotation().str() == cudaq::kernelAnnotation) + continue; + // Check if the method is not implicit (i.e., user-defined) + if (!(*methodIter)->isImplicit()) + count++; + } + return count; + }(); + + if (numMethods > 0) + reportClangError( + x, mangler, + "struct with user-defined methods is not allowed in quantum kernel."); + } + + return pushType(ty); } bool QuakeBridgeVisitor::VisitFunctionProtoType(clang::FunctionProtoType *t) { diff --git a/lib/Optimizer/CodeGen/ConvertToCC.cpp b/lib/Optimizer/CodeGen/ConvertToCC.cpp index ca8ce1b2cd..fc8972691a 100644 --- a/lib/Optimizer/CodeGen/ConvertToCC.cpp +++ b/lib/Optimizer/CodeGen/ConvertToCC.cpp @@ -42,6 +42,12 @@ struct QuakeTypeConverter : public TypeConverter { return cudaq::cc::PointerType::get( cudaq::opt::getCudaqQubitSpanType(ty.getContext())); }); + addConversion([&](quake::StruqType ty) { + SmallVector mems; + for (auto m : ty.getMembers()) + mems.push_back(convertType(m)); + return cudaq::cc::StructType::get(ty.getContext(), mems); + }); addConversion([](quake::MeasureType ty) { return IntegerType::get(ty.getContext(), 64); }); diff --git a/lib/Optimizer/CodeGen/ConvertToQIR.cpp b/lib/Optimizer/CodeGen/ConvertToQIR.cpp index 3d6690ad6e..c5b4606e2d 100644 --- a/lib/Optimizer/CodeGen/ConvertToQIR.cpp +++ b/lib/Optimizer/CodeGen/ConvertToQIR.cpp @@ -198,6 +198,13 @@ void cudaq::opt::initializeTypeConversions(LLVMTypeConverter &typeConverter) { [](quake::VeqType type) { return getArrayType(type.getContext()); }); typeConverter.addConversion( [](quake::RefType type) { return getQubitType(type.getContext()); }); + typeConverter.addConversion([&](quake::StruqType type) { + SmallVector mems; + for (auto m : type.getMembers()) + mems.push_back(typeConverter.convertType(m)); + return LLVM::LLVMStructType::getLiteral(type.getContext(), mems, + /*packed=*/false); + }); typeConverter.addConversion([](quake::MeasureType type) { return IntegerType::get(type.getContext(), 1); }); diff --git a/lib/Optimizer/CodeGen/QuakeToLLVM.cpp b/lib/Optimizer/CodeGen/QuakeToLLVM.cpp index 5686f18e16..30d2981ac4 100644 --- a/lib/Optimizer/CodeGen/QuakeToLLVM.cpp +++ b/lib/Optimizer/CodeGen/QuakeToLLVM.cpp @@ -336,6 +336,43 @@ class ExtractQubitOpRewrite } }; +class GetMemberOpPattern : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(quake::GetMemberOp extract, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto toTy = getTypeConverter()->convertType(extract.getType()); + std::int64_t position = adaptor.getIndex(); + rewriter.replaceOpWithNewOp( + extract, toTy, adaptor.getStruq(), ArrayRef{position}); + return success(); + } +}; + +class MakeStruqOpPattern : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(quake::MakeStruqOp mkStruq, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = mkStruq.getLoc(); + auto *ctx = rewriter.getContext(); + auto toTy = getTypeConverter()->convertType(mkStruq.getType()); + Value result = rewriter.create(loc, toTy); + std::int64_t count = 0; + for (auto op : adaptor.getOperands()) { + auto off = DenseI64ArrayAttr::get(ctx, ArrayRef{count}); + result = rewriter.create(loc, toTy, result, op, off); + count++; + } + rewriter.replaceOp(mkStruq, result); + return success(); + } +}; + class SubveqOpRewrite : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -1386,19 +1423,21 @@ void cudaq::opt::populateQuakeToLLVMPatterns(LLVMTypeConverter &typeConverter, auto *context = patterns.getContext(); patterns.insert(context); - patterns.insert< - AllocaOpRewrite, ConcatOpRewrite, CustomUnitaryOpRewrite, - DeallocOpRewrite, DiscriminateOpPattern, ExtractQubitOpRewrite, - ExpPauliRewrite, OneTargetRewrite, - OneTargetRewrite, OneTargetRewrite, - OneTargetRewrite, OneTargetRewrite, - OneTargetRewrite, OneTargetOneParamRewrite, - OneTargetTwoParamRewrite, - OneTargetOneParamRewrite, - OneTargetOneParamRewrite, - OneTargetOneParamRewrite, - OneTargetTwoParamRewrite, - OneTargetThreeParamRewrite, QmemRAIIOpRewrite, ResetRewrite, - SubveqOpRewrite, TwoTargetRewrite>(typeConverter); + patterns + .insert, OneTargetRewrite, + OneTargetRewrite, OneTargetRewrite, + OneTargetRewrite, OneTargetRewrite, + OneTargetOneParamRewrite, + OneTargetTwoParamRewrite, + OneTargetOneParamRewrite, + OneTargetOneParamRewrite, + OneTargetOneParamRewrite, + OneTargetTwoParamRewrite, + OneTargetThreeParamRewrite, QmemRAIIOpRewrite, + ResetRewrite, SubveqOpRewrite, TwoTargetRewrite>( + typeConverter); patterns.insert>(typeConverter, measureCounter); } diff --git a/lib/Optimizer/Dialect/Quake/QuakeOps.cpp b/lib/Optimizer/Dialect/Quake/QuakeOps.cpp index 328c993ae8..e3d9ab13d1 100644 --- a/lib/Optimizer/Dialect/Quake/QuakeOps.cpp +++ b/lib/Optimizer/Dialect/Quake/QuakeOps.cpp @@ -497,6 +497,52 @@ LogicalResult quake::ExtractRefOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// GetMemberOp +//===----------------------------------------------------------------------===// + +LogicalResult quake::GetMemberOp::verify() { + std::uint32_t index = getIndex(); + auto strTy = cast(getStruq().getType()); + std::uint32_t size = strTy.getNumMembers(); + if (index >= size) + return emitOpError("invalid index [" + std::to_string(index) + + "] because >= size [" + std::to_string(size) + "]"); + if (getType() != strTy.getMembers()[index]) + return emitOpError("result type does not match member " + + std::to_string(index) + " type"); + return success(); +} + +namespace { +struct BypassMakeStruq : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(quake::GetMemberOp getMem, + PatternRewriter &rewriter) const override { + if (auto makeStruq = + getMem.getStruq().getDefiningOp()) { + auto toStrTy = cast(getMem.getStruq().getType()); + std::uint32_t idx = getMem.getIndex(); + Value from = makeStruq.getOperand(idx); + auto toTy = toStrTy.getMembers()[idx]; + if (from.getType() != toTy) { + rewriter.replaceOpWithNewOp(getMem, toTy, from); + } else { + rewriter.replaceOp(getMem, from); + } + return success(); + } + return failure(); + } +}; +} // namespace + +void quake::GetMemberOp::getCanonicalizationPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.add(context); +} + //===----------------------------------------------------------------------===// // InitializeStateOp //===----------------------------------------------------------------------===// @@ -567,6 +613,26 @@ void quake::InitializeStateOp::getCanonicalizationPatterns( patterns.add(context); } +//===----------------------------------------------------------------------===// +// MakeStruqOp +//===----------------------------------------------------------------------===// + +LogicalResult quake::MakeStruqOp::verify() { + if (getType().getNumMembers() != getNumOperands()) + return emitOpError("result type has different member count than operands"); + for (auto [ty, opnd] : llvm::zip(getType().getMembers(), getOperands())) { + if (ty == opnd.getType()) + continue; + auto veqTy = dyn_cast(ty); + auto veqOpndTy = dyn_cast(opnd.getType()); + if (veqTy && !veqTy.hasSpecifiedSize() && veqOpndTy && + veqOpndTy.hasSpecifiedSize()) + continue; + return emitOpError("member type not compatible with operand type"); + } + return success(); +} + //===----------------------------------------------------------------------===// // RelaxSizeOp //===----------------------------------------------------------------------===// diff --git a/lib/Optimizer/Dialect/Quake/QuakeTypes.cpp b/lib/Optimizer/Dialect/Quake/QuakeTypes.cpp index 606baed123..959a869fe5 100644 --- a/lib/Optimizer/Dialect/Quake/QuakeTypes.cpp +++ b/lib/Optimizer/Dialect/Quake/QuakeTypes.cpp @@ -39,7 +39,7 @@ void quake::VeqType::print(AsmPrinter &os) const { Type quake::VeqType::parse(AsmParser &parser) { if (parser.parseLess()) return {}; - std::size_t size; + std::size_t size = 0; if (succeeded(parser.parseOptionalQuestion())) size = 0; else if (parser.parseInteger(size)) @@ -58,6 +58,42 @@ quake::VeqType::verify(llvm::function_ref emitError, //===----------------------------------------------------------------------===// +Type quake::StruqType::parse(AsmParser &parser) { + if (parser.parseLess()) + return {}; + std::string name; + auto *ctx = parser.getContext(); + StringAttr nameAttr; + if (succeeded(parser.parseOptionalString(&name))) { + nameAttr = StringAttr::get(ctx, name); + if (parser.parseColon()) + return {}; + } + SmallVector members; + do { + Type member; + auto optTy = parser.parseOptionalType(member); + if (!optTy.has_value()) + break; + if (!succeeded(*optTy)) + return {}; + members.push_back(member); + } while (succeeded(parser.parseOptionalComma())); + if (parser.parseGreater()) + return {}; + return quake::StruqType::get(ctx, nameAttr, members); +} + +void quake::StruqType::print(AsmPrinter &printer) const { + printer << '<'; + if (getName()) + printer << getName() << ": "; + llvm::interleaveComma(getMembers(), printer); + printer << '>'; +} + +//===----------------------------------------------------------------------===// + void quake::QuakeDialect::registerTypes() { - addTypes(); + addTypes(); } diff --git a/python/cudaq/kernel/ast_bridge.py b/python/cudaq/kernel/ast_bridge.py index fa50edaba3..6597b38f5d 100644 --- a/python/cudaq/kernel/ast_bridge.py +++ b/python/cudaq/kernel/ast_bridge.py @@ -251,7 +251,8 @@ def isQuantumType(self, ty): Return True if the given type is quantum (is a `VeqType` or `RefType`). Return False otherwise. """ - return quake.RefType.isinstance(ty) or quake.VeqType.isinstance(ty) + return quake.RefType.isinstance(ty) or quake.VeqType.isinstance( + ty) or quake.StruqType.isinstance(ty) def isMeasureResultType(self, ty, value): """ @@ -526,7 +527,10 @@ def getStructMemberIdx(self, memberName, structTy): the index of the variable in the struct and the specific MLIR type for the variable. """ - structName = cc.StructType.getName(structTy) + if cc.StructType.isinstance(structTy): + structName = cc.StructType.getName(structTy) + else: + structName = quake.StruqType.getName(structTy) structIdx = None _, userType = globalRegisteredTypes[structName] for i, (k, _) in enumerate(userType.items()): @@ -665,18 +669,11 @@ def convertArithmeticToSuperiorType(self, values, type): return retValues - def isQuantumStructType(self, structTy): + def isQuantumStructType(self, ty): """ - Return True if the given struct type has one or more quantum member variables. + Return True if the given struct type has only quantum member variables. """ - if not cc.StructType.isinstance(structTy): - self.emitFatalError( - f'isQuantumStructType called on type that is not a struct ({structTy})' - ) - - return True in [ - self.isQuantumType(t) for t in cc.StructType.getTypes(structTy) - ] + return quake.StruqType.isinstance(ty) def mlirTypeFromAnnotation(self, annotation): """ @@ -843,9 +840,6 @@ def needsStackSlot(self, type): function. """ # FIXME add more as we need them - if cc.StructType.isinstance(type) and self.isQuantumStructType(type): - # If we have a quantum struct, we don't want to add a stack slot - return False return ComplexType.isinstance(type) or F64Type.isinstance( type) or F32Type.isinstance(type) or IntegerType.isinstance( type) or cc.StructType.isinstance(type) @@ -927,7 +921,7 @@ def visit_FunctionDef(self, node): # Set this kernel as an entry point if the argument types are classical only def isQuantumTy(ty): return quake.RefType.isinstance(ty) or quake.VeqType.isinstance( - ty) + ty) or quake.StruqType.isinstance(ty) areQuantumTypes = [isQuantumTy(ty) for ty in self.argTypes] if True not in areQuantumTypes and not self.disableEntryPointTag: @@ -1179,17 +1173,16 @@ def visit_Attribute(self, node): if isinstance(node.value, ast.Name) and node.value.id in self.symbolTable: value = self.symbolTable[node.value.id] - if cc.StructType.isinstance( - value.type) and self.isQuantumStructType(value.type): + if self.isQuantumStructType(value.type): # Here we have a quantum struct, need to use extract value instead # of load from compute pointer. structIdx, memberTy = self.getStructMemberIdx( node.attr, value.type) self.pushValue( - cc.ExtractValueOp( - memberTy, value, [], - DenseI32ArrayAttr.get([structIdx], - context=self.ctx)).result) + quake.GetMemberOp( + memberTy, value, + IntegerAttr.get(self.getIntegerType(32), + structIdx)).result) return if cc.PointerType.isinstance(value.type): @@ -1903,24 +1896,51 @@ def bodyBuilder(iterVal): mlirTypeFromPyType(v, self.ctx) for _, v in annotations.items() ] - structTy = cc.StructType.getNamed(self.ctx, node.func.id, - structTys) + # Ensure we don't use hybrid data types + numQuantumMemberTys = sum( + [1 if self.isQuantumType(ty) else 0 for ty in structTys]) + if numQuantumMemberTys != 0: # we have quantum member types + if numQuantumMemberTys != len(structTys): + self.emitFatalError( + f'hybrid quantum-classical data types not allowed in kernel code', + node) + + isStruq = not (not structTys) + for fieldTy in structTys: + if not self.isQuantumType(fieldTy): + isStruq = False + if isStruq: + structTy = quake.StruqType.getNamed(self.ctx, node.func.id, + structTys) + # Disallow recursive quantum struct types. + for fieldTy in structTys: + if self.isQuantumStructType(fieldTy): + self.emitFatalError( + 'recursive quantum struct types not allowed.', + node) + else: + structTy = cc.StructType.getNamed(self.ctx, node.func.id, + structTys) + + # Disallow user specified methods on structs + if len({ + k: v + for k, v in cls.__dict__.items() + if not (k.startswith('__') and k.endswith('__')) + }) != 0: + self.emitFatalError( + 'struct types with user specified methods are not allowed.', + node) + nArgs = len(self.valueStack) ctorArgs = [self.popValue() for _ in range(nArgs)] ctorArgs.reverse() - if self.isQuantumStructType(structTy): - # If we have a struct with quantum types, we do not - # want to allocate struct memory and load / store pointers - # to quantum memory, so we'll instead use value semantics - # with InsertValue - undefOp = cc.UndefOp(structTy).result - for i, arg in enumerate(ctorArgs): - undefOp = cc.InsertValueOp( - structTy, undefOp, arg, - DenseI64ArrayAttr.get([i], context=self.ctx)).result - - self.pushValue(undefOp) + if isStruq: + # If we have a quantum struct. We cannot allocate classical + # memory and load / store quantum type values to that memory + # space, so use `quake.MakeStruqOp`. + self.pushValue(quake.MakeStruqOp(structTy, ctorArgs).result) return stackSlot = cc.AllocaOp(cc.PointerType.get(self.ctx, structTy), diff --git a/python/cudaq/kernel/utils.py b/python/cudaq/kernel/utils.py index 99d9705ea1..f3c0f1e52b 100644 --- a/python/cudaq/kernel/utils.py +++ b/python/cudaq/kernel/utils.py @@ -213,8 +213,37 @@ def emitFatalErrorOverride(msg): # One final check to see if this is a custom data type. if id in globalRegisteredTypes: - _, memberTys = globalRegisteredTypes[id] + pyType, memberTys = globalRegisteredTypes[id] structTys = [mlirTypeFromPyType(v, ctx) for _, v in memberTys.items()] + for ty in structTys: + if cc.StructType.isinstance(ty): + localEmitFatalError( + 'recursive struct types are not allowed in kernels.') + + if len({ + k: v + for k, v in pyType.__dict__.items() + if not (k.startswith('__') and k.endswith('__')) + }) != 0: + localEmitFatalError( + 'struct types with user specified methods are not allowed.') + + numQuantumMemberTys = sum([ + 1 if + (quake.RefType.isinstance(ty) or quake.VeqType.isinstance(ty) or + quake.StruqType.isinstance(ty)) else 0 for ty in structTys + ]) + numStruqMemberTys = sum( + [1 if (quake.StruqType.isinstance(ty)) else 0 for ty in structTys]) + if numQuantumMemberTys != 0: # we have quantum member types + if numQuantumMemberTys != len(structTys): + emitFatalError( + f'hybrid quantum-classical data types not allowed in kernel code.' + ) + if numStruqMemberTys != 0: + emitFatalError(f'recursive quantum struct types not allowed.') + return quake.StruqType.getNamed(ctx, id, structTys) + return cc.StructType.getNamed(ctx, id, structTys) localEmitFatalError( @@ -320,19 +349,37 @@ def mlirTypeFromPyType(argType, ctx, **kwargs): argInstance = kwargs['argInstance'] if isinstance(argInstance, Callable): return cc.CallableType.get(ctx, argInstance.argTypes) - else: - if argType == list[int]: - return cc.StdvecType.get(ctx, mlirTypeFromPyType(int, ctx)) - if argType == list[float]: - return cc.StdvecType.get(ctx, mlirTypeFromPyType(float, ctx)) for name, (customTys, memberTys) in globalRegisteredTypes.items(): if argType == customTys: structTys = [ mlirTypeFromPyType(v, ctx) for _, v in memberTys.items() ] + numQuantumMemberTys = sum([ + 1 if + (quake.RefType.isinstance(ty) or quake.VeqType.isinstance(ty) or + quake.StruqType.isinstance(ty)) else 0 for ty in structTys + ]) + numStruqMemberTys = sum([ + 1 if (quake.StruqType.isinstance(ty)) else 0 for ty in structTys + ]) + if numQuantumMemberTys != 0: # we have quantum member types + if numQuantumMemberTys != len(structTys): + emitFatalError( + f'hybrid quantum-classical data types not allowed') + if numStruqMemberTys != 0: + emitFatalError( + f'recursive quantum struct types not allowed.') + return quake.StruqType.getNamed(ctx, name, structTys) + return cc.StructType.getNamed(ctx, name, structTys) + if 'argInstance' not in kwargs: + if argType == list[int]: + return cc.StdvecType.get(ctx, mlirTypeFromPyType(int, ctx)) + if argType == list[float]: + return cc.StdvecType.get(ctx, mlirTypeFromPyType(float, ctx)) + emitFatalError( f"Can not handle conversion of python type {argType} to MLIR type.") diff --git a/python/runtime/mlir/py_register_dialects.cpp b/python/runtime/mlir/py_register_dialects.cpp index f87d7d479c..3dd5a66ff3 100644 --- a/python/runtime/mlir/py_register_dialects.cpp +++ b/python/runtime/mlir/py_register_dialects.cpp @@ -97,6 +97,50 @@ void registerQuakeDialectAndTypes(py::module &m) { return veqTy.getSize(); }, py::arg("veqTypeInstance")); + + mlir_type_subclass( + quakeMod, "StruqType", + [](MlirType type) { return unwrap(type).isa(); }) + .def_classmethod( + "get", + [](py::object cls, MlirContext ctx, py::list aggregateTypes) { + SmallVector inTys; + for (auto &t : aggregateTypes) + inTys.push_back(unwrap(t.cast())); + + return wrap(quake::StruqType::get(unwrap(ctx), inTys)); + }) + .def_classmethod("getNamed", + [](py::object cls, MlirContext ctx, + const std::string &name, py::list aggregateTypes) { + SmallVector inTys; + for (auto &t : aggregateTypes) + inTys.push_back(unwrap(t.cast())); + + return wrap( + quake::StruqType::get(unwrap(ctx), name, inTys)); + }) + .def_classmethod( + "getTypes", + [](py::object cls, MlirType structTy) { + auto ty = dyn_cast(unwrap(structTy)); + if (!ty) + throw std::runtime_error( + "invalid type passed to StruqType.getTypes(), must be a " + "quake.struq"); + std::vector ret; + for (auto &t : ty.getMembers()) + ret.push_back(wrap(t)); + return ret; + }) + .def_classmethod("getName", [](py::object cls, MlirType structTy) { + auto ty = dyn_cast(unwrap(structTy)); + if (!ty) + throw std::runtime_error( + "invalid type passed to StruqType.getName(), must be a " + "quake.struq"); + return ty.getName().getValue().str(); + }); } void registerCCDialectAndTypes(py::module &m) { diff --git a/python/tests/kernel/test_kernel_features.py b/python/tests/kernel/test_kernel_features.py index 52624cf6ed..b6267cb75d 100644 --- a/python/tests/kernel/test_kernel_features.py +++ b/python/tests/kernel/test_kernel_features.py @@ -194,12 +194,12 @@ def grover(N: int, M: int, oracle: Callable[[cudaq.qview], None]): def test_pauli_word_input(): h2_data = [ - 3, 1, 1, 3, 0.0454063, 0, 2, 0, 0, 0, 0.17028, 0, 0, 0, 2, 0, -0.220041, - -0, 1, 3, 3, 1, 0.0454063, 0, 0, 0, 0, 0, -0.106477, 0, 0, 2, 0, 0, - 0.17028, 0, 0, 0, 0, 2, -0.220041, -0, 3, 3, 1, 1, -0.0454063, -0, 2, 2, - 0, 0, 0.168336, 0, 2, 0, 2, 0, 0.1202, 0, 0, 2, 0, 2, 0.1202, 0, 2, 0, - 0, 2, 0.165607, 0, 0, 2, 2, 0, 0.165607, 0, 0, 0, 2, 2, 0.174073, 0, 1, - 1, 3, 3, -0.0454063, -0, 15 + 3, 1, 1, 3, 0.0454063, 0, 2, 0, 0, 0, 0.17028, 0, 0, 0, 2, 0, + -0.220041, -0, 1, 3, 3, 1, 0.0454063, 0, 0, 0, 0, 0, -0.106477, 0, 0, + 2, 0, 0, 0.17028, 0, 0, 0, 0, 2, -0.220041, -0, 3, 3, 1, 1, -0.0454063, + -0, 2, 2, 0, 0, 0.168336, 0, 2, 0, 2, 0, 0.1202, 0, 0, 2, 0, 2, 0.1202, + 0, 2, 0, 0, 2, 0.165607, 0, 0, 2, 2, 0, 0.165607, 0, 0, 0, 2, 2, + 0.174073, 0, 1, 1, 3, 3, -0.0454063, -0, 15 ] h = cudaq.SpinOperator(h2_data, 4) @@ -242,12 +242,12 @@ def test(theta: float, paulis: list[cudaq.pauli_word]): def test_exp_pauli(): h2_data = [ - 3, 1, 1, 3, 0.0454063, 0, 2, 0, 0, 0, 0.17028, 0, 0, 0, 2, 0, -0.220041, - -0, 1, 3, 3, 1, 0.0454063, 0, 0, 0, 0, 0, -0.106477, 0, 0, 2, 0, 0, - 0.17028, 0, 0, 0, 0, 2, -0.220041, -0, 3, 3, 1, 1, -0.0454063, -0, 2, 2, - 0, 0, 0.168336, 0, 2, 0, 2, 0, 0.1202, 0, 0, 2, 0, 2, 0.1202, 0, 2, 0, - 0, 2, 0.165607, 0, 0, 2, 2, 0, 0.165607, 0, 0, 0, 2, 2, 0.174073, 0, 1, - 1, 3, 3, -0.0454063, -0, 15 + 3, 1, 1, 3, 0.0454063, 0, 2, 0, 0, 0, 0.17028, 0, 0, 0, 2, 0, + -0.220041, -0, 1, 3, 3, 1, 0.0454063, 0, 0, 0, 0, 0, -0.106477, 0, 0, + 2, 0, 0, 0.17028, 0, 0, 0, 0, 2, -0.220041, -0, 3, 3, 1, 1, -0.0454063, + -0, 2, 2, 0, 0, 0.168336, 0, 2, 0, 2, 0, 0.1202, 0, 0, 2, 0, 2, 0.1202, + 0, 2, 0, 0, 2, 0.165607, 0, 0, 2, 2, 0, 0.165607, 0, 0, 0, 2, 2, + 0.174073, 0, 1, 1, 3, 3, -0.0454063, -0, 15 ] h = cudaq.SpinOperator(h2_data, 4) @@ -1725,6 +1725,102 @@ def run(): run() +def test_disallow_hybrid_types(): + from dataclasses import dataclass + # Ensure we don't allow hybrid type s + @dataclass + class hybrid: + q: cudaq.qview + i: int + + with pytest.raises(RuntimeError) as e: + + @cudaq.kernel + def test(): + q = cudaq.qvector(2) + h = hybrid(q, 1) + + test() + + with pytest.raises(RuntimeError) as e: + + @cudaq.kernel + def testtest(h: hybrid): + x(h.q[h.i]) + + testtest.compile() + + +def test_disallow_quantum_struct_return(): + from dataclasses import dataclass + # Ensure we don't allow hybrid type s + @dataclass + class T: + q: cudaq.qview + + with pytest.raises(RuntimeError) as e: + + @cudaq.kernel + def test() -> T: + q = cudaq.qvector(2) + h = T(q) + return h + + test() + +def test_disallow_recursive_quantum_struct(): + from dataclasses import dataclass + @dataclass + class T: + q: cudaq.qview + + @dataclass + class Holder: + t : T + + with pytest.raises(RuntimeError) as e: + + @cudaq.kernel + def test(): + q = cudaq.qvector(2) + t = T(q) + hh = Holder(t) + + print(test) + + with pytest.raises(RuntimeError) as e: + + @cudaq.kernel + def test(hh : Holder): + pass + + print(test) + +def test_disallow_struct_with_methods(): + from dataclasses import dataclass + @dataclass + class T: + q: cudaq.qview + def doSomething(self): + pass + + with pytest.raises(RuntimeError) as e: + + @cudaq.kernel + def test(t : T): + pass + + print(test) + + with pytest.raises(RuntimeError) as e: + + @cudaq.kernel + def test(): + q = cudaq.qvector(2) + t = T(q) + print(test) + + @skipIfPythonLessThan39 def test_issue_9(): diff --git a/python/tests/mlir/quantum_struct.py b/python/tests/mlir/quantum_struct.py new file mode 100644 index 0000000000..8ed8081bbe --- /dev/null +++ b/python/tests/mlir/quantum_struct.py @@ -0,0 +1,39 @@ +# ============================================================================ # +# Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. # +# All rights reserved. # +# # +# This source code and the accompanying materials are made available under # +# the terms of the Apache License 2.0 which accompanies this distribution. # +# ============================================================================ # + +# RUN: PYTHONPATH=../../ pytest -rP %s | FileCheck %s + + +import pytest +import cudaq +from dataclasses import dataclass + +def test_quantum_struct(): + @dataclass + class patch: + q : cudaq.qview + r : cudaq.qview + + @cudaq.kernel + def entry(): + q = cudaq.qvector(2) + r = cudaq.qvector(2) + p = patch(q, r) + h(p.r[0]) + + print(entry) + +# CHECK-LABEL: func.func @__nvqpp__mlirgen__entry() +# CHECK: %[[VAL_0:.*]] = quake.alloca !quake.veq<2> +# CHECK: %[[VAL_1:.*]] = quake.alloca !quake.veq<2> +# The struq type is erased in this example. +# CHECK: %[[VAL_2:.*]] = quake.extract_ref %[[VAL_1]][0] : (!quake.veq<2>) -> !quake.ref +# CHECK: quake.h %[[VAL_2]] : (!quake.ref) -> () +# CHECK: return +# CHECK: } + diff --git a/python/tests/mlir/quantum_type.py b/python/tests/mlir/quantum_type.py index b204c10f77..3fec7bbcaf 100644 --- a/python/tests/mlir/quantum_type.py +++ b/python/tests/mlir/quantum_type.py @@ -6,7 +6,8 @@ # the terms of the Apache License 2.0 which accompanies this distribution. # # ============================================================================ # -# RUN: PYTHONPATH=../../ pytest -rP %s | FileCheck %s +# Workaround for kernels that may appear in jumbled order. +# RUN: PYTHONPATH=../../ pytest -rP %s > %t && FileCheck %s < %t && FileCheck --check-prefix=NAUGHTY %s < %t && FileCheck --check-prefix=NICE %s < %t import pytest @@ -47,41 +48,35 @@ def run(): # Test here is that it compiles and runs successfully print(run) -# CHECK-LABEL: func.func @__nvqpp__mlirgen__logicalH( -# CHECK-SAME: %[[VAL_0:.*]]: !cc.struct<"patch" {!quake.veq, !quake.veq, !quake.veq}>) attributes {"cudaq-entrypoint"} { -# CHECK: %[[VAL_1:.*]] = arith.constant 1 : i64 -# CHECK: %[[VAL_2:.*]] = arith.constant 0 : i64 -# CHECK: %[[VAL_3:.*]] = cc.extract_value %[[VAL_0]][0] : (!cc.struct<"patch" {!quake.veq, !quake.veq, !quake.veq}>) -> !quake.veq +# NAUGHTY-LABEL: func.func @__nvqpp__mlirgen__logicalH( +# NAUGHTY-SAME: %[[VAL_0:.*]]: !quake.struq<"patch": !quake.veq, !quake.veq, !quake.veq>) { +# NAUGHTY: %[[VAL_3:.*]] = quake.get_member %[[VAL_0]][0] : (!quake.struq<"patch": !quake.veq, !quake.veq, !quake.veq>) -> !quake.veq +# NAUGHTY: %[[VAL_4:.*]] = quake.veq_size %[[VAL_3]] : (!quake.veq) -> i64 +# NAUGHTY: return +# NAUGHTY: } + +# NICE-LABEL: func.func @__nvqpp__mlirgen__logicalX( +# NICE-SAME: %[[VAL_0:.*]]: !quake.struq<"patch": !quake.veq, !quake.veq, !quake.veq>) { +# NICE: %[[VAL_3:.*]] = quake.get_member %[[VAL_0]][1] : (!quake.struq<"patch": !quake.veq, !quake.veq, !quake.veq>) -> !quake.veq +# NICE: %[[VAL_4:.*]] = quake.veq_size %[[VAL_3]] : (!quake.veq) -> i64 +# NICE: return +# NICE: } + +# CHECK-LABEL: func.func @__nvqpp__mlirgen__logicalZ( +# CHECK-SAME: %[[VAL_0:.*]]: !quake.struq<"patch": !quake.veq, !quake.veq, !quake.veq>) { +# CHECK: %[[VAL_3:.*]] = quake.get_member %[[VAL_0]][2] : (!quake.struq<"patch": !quake.veq, !quake.veq, !quake.veq>) -> !quake.veq # CHECK: %[[VAL_4:.*]] = quake.veq_size %[[VAL_3]] : (!quake.veq) -> i64 -# CHECK: %[[VAL_5:.*]] = cc.loop while ((%[[VAL_6:.*]] = %[[VAL_2]]) -> (i64)) { -# CHECK: %[[VAL_7:.*]] = arith.cmpi slt, %[[VAL_6]], %[[VAL_4]] : i64 -# CHECK: cc.condition %[[VAL_7]](%[[VAL_6]] : i64) -# CHECK: } do { -# CHECK: ^bb0(%[[VAL_8:.*]]: i64): -# CHECK: %[[VAL_9:.*]] = quake.extract_ref %[[VAL_3]]{{\[}}%[[VAL_8]]] : (!quake.veq, i64) -> !quake.ref -# CHECK: quake.h %[[VAL_9]] : (!quake.ref) -> () -# CHECK: cc.continue %[[VAL_8]] : i64 -# CHECK: } step { -# CHECK: ^bb0(%[[VAL_10:.*]]: i64): -# CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_1]] : i64 -# CHECK: cc.continue %[[VAL_11]] : i64 -# CHECK: } {invariant} # CHECK: return # CHECK: } -# CHECK-LABEL: func.func @__nvqpp__mlirgen__run() attributes {"cudaq-entrypoint"} { +# CHECK-LABEL: func.func @__nvqpp__mlirgen__run() # CHECK: %[[VAL_0:.*]] = quake.alloca !quake.veq<2> -# CHECK: %[[VAL_1:.*]] = quake.relax_size %[[VAL_0]] : (!quake.veq<2>) -> !quake.veq +# CHECK: %[[VAL_1:.*]] = quake.alloca !quake.veq<2> # CHECK: %[[VAL_2:.*]] = quake.alloca !quake.veq<2> -# CHECK: %[[VAL_3:.*]] = quake.relax_size %[[VAL_2]] : (!quake.veq<2>) -> !quake.veq -# CHECK: %[[VAL_4:.*]] = quake.alloca !quake.veq<2> -# CHECK: %[[VAL_5:.*]] = quake.relax_size %[[VAL_4]] : (!quake.veq<2>) -> !quake.veq -# CHECK: %[[VAL_6:.*]] = cc.undef !cc.struct<"patch" {!quake.veq, !quake.veq, !quake.veq}> -# CHECK: %[[VAL_7:.*]] = cc.insert_value %[[VAL_1]], %[[VAL_6]][0] : (!cc.struct<"patch" {!quake.veq, !quake.veq, !quake.veq}>, !quake.veq) -> !cc.struct<"patch" {!quake.veq, !quake.veq, !quake.veq}> -# CHECK: %[[VAL_8:.*]] = cc.insert_value %[[VAL_3]], %[[VAL_7]][1] : (!cc.struct<"patch" {!quake.veq, !quake.veq, !quake.veq}>, !quake.veq) -> !cc.struct<"patch" {!quake.veq, !quake.veq, !quake.veq}> -# CHECK: %[[VAL_9:.*]] = cc.insert_value %[[VAL_5]], %[[VAL_8]][2] : (!cc.struct<"patch" {!quake.veq, !quake.veq, !quake.veq}>, !quake.veq) -> !cc.struct<"patch" {!quake.veq, !quake.veq, !quake.veq}> -# CHECK: call @__nvqpp__mlirgen__logicalH(%[[VAL_9]]) : (!cc.struct<"patch" {!quake.veq, !quake.veq, !quake.veq}>) -> () -# CHECK: call @__nvqpp__mlirgen__logicalX(%[[VAL_9]]) : (!cc.struct<"patch" {!quake.veq, !quake.veq, !quake.veq}>) -> () -# CHECK: call @__nvqpp__mlirgen__logicalZ(%[[VAL_9]]) : (!cc.struct<"patch" {!quake.veq, !quake.veq, !quake.veq}>) -> () +# CHECK: %[[VAL_3:.*]] = quake.make_struq %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : (!quake.veq<2>, !quake.veq<2>, !quake.veq<2>) -> !quake.struq<"patch": !quake.veq, !quake.veq, !quake.veq> +# CHECK: call @__nvqpp__mlirgen__logicalH(%[[VAL_3]]) : (!quake.struq<"patch": !quake.veq, !quake.veq, !quake.veq>) -> () +# CHECK: call @__nvqpp__mlirgen__logicalX(%[[VAL_3]]) : (!quake.struq<"patch": !quake.veq, !quake.veq, !quake.veq>) -> () +# CHECK: call @__nvqpp__mlirgen__logicalZ(%[[VAL_3]]) : (!quake.struq<"patch": !quake.veq, !quake.veq, !quake.veq>) -> () # CHECK: return -# CHECK: } \ No newline at end of file +# CHECK: } + diff --git a/python/tests/mlir/test_output_qir.py b/python/tests/mlir/test_output_qir.py index 22f50704d7..eab4ca9171 100644 --- a/python/tests/mlir/test_output_qir.py +++ b/python/tests/mlir/test_output_qir.py @@ -22,9 +22,9 @@ def ghz(numQubits: int): for i, qubitIdx in enumerate(range(numQubits - 1)): x.ctrl(qubits[i], qubits[qubitIdx + 1]) - print(cudaq.to_qir(ghz)) + print(cudaq.translate(ghz, format="qir")) ghz_synth = cudaq.synthesize(ghz, 5) - print(cudaq.to_qir(ghz_synth, profile='qir-base')) + print(cudaq.translate(ghz_synth, format='qir-base')) # CHECK: %[[VAL_0:.*]] = tail call diff --git a/targettests/execution/quantum_struct.cpp b/targettests/execution/quantum_struct.cpp new file mode 100644 index 0000000000..929a2cd615 --- /dev/null +++ b/targettests/execution/quantum_struct.cpp @@ -0,0 +1,70 @@ +/******************************************************************************* + * Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. * + ******************************************************************************/ + +// clang-format off +// RUN: nvq++ %cpp_std --enable-mlir %s -o %t && %t | FileCheck %s +// clang-format on + +#include +#include + +struct PureQuantumStruct { + cudaq::qview<> view1; + cudaq::qview<> view2; +}; + +struct Fehu { + void operator()(cudaq::qview<> v) __qpu__ { h(v); } +}; + +struct Ansuz { + void operator()(cudaq::qview<> v) __qpu__ { x(v); } +}; + +struct Uruz { + void operator()(PureQuantumStruct group) __qpu__ { + Ansuz{}(group.view1); + Fehu{}(group.view1); + Fehu{}(group.view2); + Ansuz{}(group.view2); + } +}; + +struct Thurisaz { + void operator()() __qpu__ { + cudaq::qvector v1(2); + cudaq::qvector v2(3); + PureQuantumStruct pqs{v1, v2}; + Uruz{}(pqs); + mz(v1); + mz(v2); + } +}; + +int main() { + auto result = cudaq::sample(Thurisaz{}); + int flags[1 << 5] = {0}; + for (auto &&[b, c] : result) { + int off = std::stoi(b, nullptr, 2); + if (off >= (1 << 5) || off < 0) { + std::cout << "Amazingly incorrect: " << b << '\n'; + return 1; + } + flags[off] = 1 + c; + } + for (int i = 0; i < (1 << 5); ++i) { + if (flags[i] == 0) { + std::cout << "FAILED!\n"; + return 1; + } + } + std::cout << "Wahoo!\n"; + return 0; +} + +// CHECK: Wahoo diff --git a/test/AST-Quake/pure_quantum_struct.cpp b/test/AST-Quake/pure_quantum_struct.cpp new file mode 100644 index 0000000000..6f257d4d0d --- /dev/null +++ b/test/AST-Quake/pure_quantum_struct.cpp @@ -0,0 +1,157 @@ +/******************************************************************************* + * Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. * + ******************************************************************************/ + +// clang-format off +// RUN: cudaq-quake %cpp_std %s | cudaq-opt | FileCheck %s +// RUN: cudaq-quake %cpp_std %s | cudaq-translate --convert-to=qir | FileCheck --check-prefix=QIR %s +// clang-format on + +#include "cudaq.h" + +struct test { + cudaq::qview<> q; + cudaq::qview<> r; +}; + +__qpu__ void applyH(cudaq::qubit &q) { h(q); } +__qpu__ void applyX(cudaq::qubit &q) { x(q); } +__qpu__ void kernel(test t) { + h(t.q); + s(t.r); + + applyH(t.q[0]); + applyX(t.r[0]); +} + +// clang-format off +// CHECK-LABEL: func.func @__nvqpp__mlirgen__function_kernel._Z6kernel4test( +// CHECK-SAME: %[[VAL_0:.*]]: !quake.struq, !quake.veq>) attributes {"cudaq-entrypoint", "cudaq-kernel", no_this} { +// CHECK: %[[VAL_1:.*]] = arith.constant 1 : i64 +// CHECK: %[[VAL_2:.*]] = arith.constant 0 : i64 +// CHECK: %[[VAL_3:.*]] = quake.get_member %[[VAL_0]][0] : (!quake.struq, !quake.veq>) -> !quake.veq +// CHECK: %[[VAL_4:.*]] = quake.veq_size %[[VAL_3]] : (!quake.veq) -> i64 +// CHECK: %[[VAL_12:.*]] = quake.get_member %[[VAL_0]][1] : (!quake.struq, !quake.veq>) -> !quake.veq +// CHECK: %[[VAL_13:.*]] = quake.veq_size %[[VAL_12]] : (!quake.veq) -> i64 +// CHECK: %[[VAL_21:.*]] = quake.get_member %[[VAL_0]][0] : (!quake.struq, !quake.veq>) -> !quake.veq +// CHECK: %[[VAL_22:.*]] = quake.extract_ref %[[VAL_21]][0] : (!quake.veq) -> !quake.ref +// CHECK: call @__nvqpp__mlirgen__function_applyH._Z6applyHRN5cudaq5quditILm2EEE(%[[VAL_22]]) : (!quake.ref) -> () +// CHECK: %[[VAL_23:.*]] = quake.get_member %[[VAL_0]][1] : (!quake.struq, !quake.veq>) -> !quake.veq +// CHECK: %[[VAL_24:.*]] = quake.extract_ref %[[VAL_23]][0] : (!quake.veq) -> !quake.ref +// CHECK: call @__nvqpp__mlirgen__function_applyX._Z6applyXRN5cudaq5quditILm2EEE(%[[VAL_24]]) : (!quake.ref) -> () +// CHECK: return +// CHECK: } +// clang-format on + +__qpu__ void entry_initlist() { + cudaq::qvector q(2), r(2); + test tt{q, r}; + kernel(tt); +} + +// clang-format off +// CHECK-LABEL: func.func @__nvqpp__mlirgen__function_entry_initlist._Z14entry_initlistv() attributes {"cudaq-entrypoint", "cudaq-kernel", no_this} { +// CHECK: %[[VAL_0:.*]] = quake.alloca !quake.veq<2> +// CHECK: %[[VAL_1:.*]] = quake.alloca !quake.veq<2> +// CHECK: %[[VAL_2:.*]] = quake.make_struq %[[VAL_0]], %[[VAL_1]] : (!quake.veq<2>, !quake.veq<2>) -> !quake.struq, !quake.veq> +// CHECK: call @__nvqpp__mlirgen__function_kernel._Z6kernel4test(%[[VAL_2]]) : (!quake.struq, !quake.veq>) -> () +// CHECK: return +// CHECK: } +// clang-format on + +__qpu__ void entry_ctor() { + cudaq::qvector q(2), r(2); + test tt(q, r); + h(tt.r[0]); +} + +// clang-format off +// CHECK-LABEL: func.func @__nvqpp__mlirgen__function_entry_ctor._Z10entry_ctorv() attributes {"cudaq-entrypoint", "cudaq-kernel", no_this} { +// CHECK: %[[VAL_0:.*]] = quake.alloca !quake.veq<2> +// CHECK: %[[VAL_1:.*]] = quake.alloca !quake.veq<2> +// CHECK: %[[VAL_2:.*]] = quake.extract_ref %[[VAL_1]][0] : (!quake.veq<2>) -> !quake.ref +// CHECK: quake.h %[[VAL_2]] : (!quake.ref) -> () +// CHECK: return +// CHECK: } +// clang-format on + +// clang-format off +// QIR-LABEL: define void @__nvqpp__mlirgen__function_kernel._Z6kernel4test({ +// QIR-SAME: %[[VAL_0:.*]]*, %[[VAL_0]]* } %[[VAL_1:.*]]) local_unnamed_addr { +// QIR: %[[VAL_2:.*]] = extractvalue { %[[VAL_0]]*, %[[VAL_0]]* } %[[VAL_1]], 0 +// QIR: %[[VAL_3:.*]] = tail call i64 @__quantum__rt__array_get_size_1d(%[[VAL_0]]* %[[VAL_2]]) +// QIR: %[[VAL_4:.*]] = icmp sgt i64 %[[VAL_3]], 0 +// QIR: br i1 %[[VAL_4]], label %[[VAL_5:.*]], label %[[VAL_6:.*]] +// QIR: .lr.ph: ; preds = %[[VAL_7:.*]], %[[VAL_5]] +// QIR: %[[VAL_8:.*]] = phi i64 [ %[[VAL_9:.*]], %[[VAL_5]] ], [ 0, %[[VAL_7]] ] +// QIR: %[[VAL_10:.*]] = tail call i8* @__quantum__rt__array_get_element_ptr_1d(%[[VAL_0]]* %[[VAL_2]], i64 %[[VAL_8]]) +// QIR: %[[VAL_11:.*]] = bitcast i8* %[[VAL_10]] to %[[VAL_12:.*]]** +// QIR: %[[VAL_13:.*]] = load %[[VAL_12]]*, %[[VAL_12]]** %[[VAL_11]], align 8 +// QIR: tail call void @__quantum__qis__h(%[[VAL_12]]* %[[VAL_13]]) +// QIR: %[[VAL_9]] = add nuw nsw i64 %[[VAL_8]], 1 +// QIR: %[[VAL_14:.*]] = icmp eq i64 %[[VAL_9]], %[[VAL_3]] +// QIR: br i1 %[[VAL_14]], label %[[VAL_6]], label %[[VAL_5]] +// QIR: ._crit_edge: ; preds = %[[VAL_5]], %[[VAL_7]] +// QIR: %[[VAL_15:.*]] = extractvalue { %[[VAL_0]]*, %[[VAL_0]]* } %[[VAL_1]], 1 +// QIR: %[[VAL_16:.*]] = tail call i64 @__quantum__rt__array_get_size_1d(%[[VAL_0]]* %[[VAL_15]]) +// QIR: %[[VAL_17:.*]] = icmp sgt i64 %[[VAL_16]], 0 +// QIR: br i1 %[[VAL_17]], label %[[VAL_18:.*]], label %[[VAL_19:.*]] +// QIR: .lr.ph3: ; preds = %[[VAL_6]], %[[VAL_18]] +// QIR: %[[VAL_20:.*]] = phi i64 [ %[[VAL_21:.*]], %[[VAL_18]] ], [ 0, %[[VAL_6]] ] +// QIR: %[[VAL_22:.*]] = tail call i8* @__quantum__rt__array_get_element_ptr_1d(%[[VAL_0]]* %[[VAL_15]], i64 %[[VAL_20]]) +// QIR: %[[VAL_23:.*]] = bitcast i8* %[[VAL_22]] to %[[VAL_12]]** +// QIR: %[[VAL_24:.*]] = load %[[VAL_12]]*, %[[VAL_12]]** %[[VAL_23]], align 8 +// QIR: tail call void @__quantum__qis__s(%[[VAL_12]]* %[[VAL_24]]) +// QIR: %[[VAL_21]] = add nuw nsw i64 %[[VAL_20]], 1 +// QIR: %[[VAL_25:.*]] = icmp eq i64 %[[VAL_21]], %[[VAL_16]] +// QIR: br i1 %[[VAL_25]], label %[[VAL_19]], label %[[VAL_18]] +// QIR: ._crit_edge4: ; preds = %[[VAL_18]], %[[VAL_6]] +// QIR: %[[VAL_26:.*]] = tail call i8* @__quantum__rt__array_get_element_ptr_1d(%[[VAL_0]]* %[[VAL_2]], i64 0) +// QIR: %[[VAL_27:.*]] = bitcast i8* %[[VAL_26]] to %[[VAL_12]]** +// QIR: %[[VAL_28:.*]] = load %[[VAL_12]]*, %[[VAL_12]]** %[[VAL_27]], align 8 +// QIR: tail call void @__quantum__qis__h(%[[VAL_12]]* %[[VAL_28]]) +// QIR: %[[VAL_29:.*]] = tail call i8* @__quantum__rt__array_get_element_ptr_1d(%[[VAL_0]]* %[[VAL_15]], i64 0) +// QIR: %[[VAL_30:.*]] = bitcast i8* %[[VAL_29]] to %[[VAL_12]]** +// QIR: %[[VAL_31:.*]] = load %[[VAL_12]]*, %[[VAL_12]]** %[[VAL_30]], align 8 +// QIR: tail call void @__quantum__qis__x(%[[VAL_12]]* %[[VAL_31]]) +// QIR: ret void +// QIR: } + +// QIR-LABEL: define void @__nvqpp__mlirgen__function_entry_initlist._Z14entry_initlistv() local_unnamed_addr { +// QIR: %[[VAL_0:.*]] = tail call %[[VAL_1:.*]]* @__quantum__rt__qubit_allocate_array(i64 4) +// QIR: %[[VAL_2:.*]] = tail call i8* @__quantum__rt__array_get_element_ptr_1d(%[[VAL_1]]* %[[VAL_0]], i64 0) +// QIR: %[[VAL_3:.*]] = bitcast i8* %[[VAL_2]] to %[[VAL_4:.*]]** +// QIR: %[[VAL_5:.*]] = load %[[VAL_4]]*, %[[VAL_4]]** %[[VAL_3]], align 8 +// QIR: tail call void @__quantum__qis__h(%[[VAL_4]]* %[[VAL_5]]) +// QIR: %[[VAL_6:.*]] = tail call i8* @__quantum__rt__array_get_element_ptr_1d(%[[VAL_1]]* %[[VAL_0]], i64 1) +// QIR: %[[VAL_7:.*]] = bitcast i8* %[[VAL_6]] to %[[VAL_4]]** +// QIR: %[[VAL_8:.*]] = load %[[VAL_4]]*, %[[VAL_4]]** %[[VAL_7]], align 8 +// QIR: tail call void @__quantum__qis__h(%[[VAL_4]]* %[[VAL_8]]) +// QIR: %[[VAL_9:.*]] = tail call i8* @__quantum__rt__array_get_element_ptr_1d(%[[VAL_1]]* %[[VAL_0]], i64 2) +// QIR: %[[VAL_10:.*]] = bitcast i8* %[[VAL_9]] to %[[VAL_4]]** +// QIR: %[[VAL_11:.*]] = load %[[VAL_4]]*, %[[VAL_4]]** %[[VAL_10]], align 8 +// QIR: tail call void @__quantum__qis__s(%[[VAL_4]]* %[[VAL_11]]) +// QIR: %[[VAL_12:.*]] = tail call i8* @__quantum__rt__array_get_element_ptr_1d(%[[VAL_1]]* %[[VAL_0]], i64 3) +// QIR: %[[VAL_13:.*]] = bitcast i8* %[[VAL_12]] to %[[VAL_4]]** +// QIR: %[[VAL_14:.*]] = load %[[VAL_4]]*, %[[VAL_4]]** %[[VAL_13]], align 8 +// QIR: tail call void @__quantum__qis__s(%[[VAL_4]]* %[[VAL_14]]) +// QIR: tail call void @__quantum__qis__h(%[[VAL_4]]* %[[VAL_5]]) +// QIR: tail call void @__quantum__qis__x(%[[VAL_4]]* %[[VAL_11]]) +// QIR: tail call void @__quantum__rt__qubit_release_array(%[[VAL_1]]* %[[VAL_0]]) +// QIR: ret void +// QIR: } + +// QIR-LABEL: define void @__nvqpp__mlirgen__function_entry_ctor._Z10entry_ctorv() local_unnamed_addr { +// QIR: %[[VAL_0:.*]] = tail call %[[VAL_1:.*]]* @__quantum__rt__qubit_allocate_array(i64 4) +// QIR: %[[VAL_2:.*]] = tail call i8* @__quantum__rt__array_get_element_ptr_1d(%[[VAL_1]]* %[[VAL_0]], i64 2) +// QIR: %[[VAL_3:.*]] = bitcast i8* %[[VAL_2]] to %[[VAL_4:.*]]** +// QIR: %[[VAL_5:.*]] = load %[[VAL_4]]*, %[[VAL_4]]** %[[VAL_3]], align 8 +// QIR: tail call void @__quantum__qis__h(%[[VAL_4]]* %[[VAL_5]]) +// QIR: tail call void @__quantum__rt__qubit_release_array(%[[VAL_1]]* %[[VAL_0]]) +// QIR: ret void +// QIR: } +// clang-format on diff --git a/test/AST-error/kernel_invalid_argument-2.cpp b/test/AST-error/kernel_invalid_argument.cpp similarity index 100% rename from test/AST-error/kernel_invalid_argument-2.cpp rename to test/AST-error/kernel_invalid_argument.cpp diff --git a/test/AST-error/kernel_with_member_functions.cpp b/test/AST-error/kernel_with_member_functions.cpp new file mode 100644 index 0000000000..7275e15f1c --- /dev/null +++ b/test/AST-error/kernel_with_member_functions.cpp @@ -0,0 +1,23 @@ +/******************************************************************************* + * Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. * + ******************************************************************************/ + +// REQUIRES: c++20 +// RUN: cudaq-quake %s -verify + +#include "cudaq.h" + +// expected-error@+1 {{struct with user-defined methods is not allowed}} +struct test { + cudaq::qview<> q; + int myMethod() { return 0; } +}; + +__qpu__ void kernel() { + cudaq::qvector q(2); + test t(q); +} diff --git a/test/AST-error/quantum_struct_declarations.cpp b/test/AST-error/quantum_struct_declarations.cpp new file mode 100644 index 0000000000..e2862633a8 --- /dev/null +++ b/test/AST-error/quantum_struct_declarations.cpp @@ -0,0 +1,44 @@ +/******************************************************************************* + * Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. * + ******************************************************************************/ + +// REQUIRES: c++20 +// RUN: cudaq-quake %s -verify + +#include "cudaq.h" + +// expected-error@+1 {{quantum struct has invalid member type}} +struct error1 { + cudaq::qvector<4> wrong; +}; + +__qpu__ void bug1(error1&); + +// expected-error@+1 {{quantum struct has invalid member type}} +struct error2 { + cudaq::qubit cubit; +}; + +__qpu__ void bug2(error2&); + +// expected-error@+2 {{quantum struct has invalid member type}} +// expected-error@+1 {{quantum struct has invalid member type}} +struct error3 { + cudaq::qubit nope; + cudaq::qvector<2> sorry; +}; + +__qpu__ void bug3(error3&); + +__qpu__ void funny() { + error1 e1; + error2 e2; + error3 e3; + bug1(e1); + bug2(e2); + bug3(e3); +} diff --git a/test/AST-error/quantum_struct_signature.cpp b/test/AST-error/quantum_struct_signature.cpp new file mode 100644 index 0000000000..88e39f132e --- /dev/null +++ b/test/AST-error/quantum_struct_signature.cpp @@ -0,0 +1,23 @@ +/******************************************************************************* + * Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. * + ******************************************************************************/ + +// REQUIRES: c++20 +// RUN: cudaq-quake %s -verify + +#include "cudaq.h" + +struct test { + cudaq::qubit &r; + cudaq::qview<> q; +}; + +// expected-error@+1 {{kernel result type not supported}} +__qpu__ test kernel(cudaq::qubit &q, cudaq::qview<> qq) { + test result(q, qq); + return result; +} diff --git a/test/AST-error/quantum_struct_with_struct_member.cpp b/test/AST-error/quantum_struct_with_struct_member.cpp new file mode 100644 index 0000000000..74fe2024f6 --- /dev/null +++ b/test/AST-error/quantum_struct_with_struct_member.cpp @@ -0,0 +1,29 @@ +/******************************************************************************* + * Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. * + ******************************************************************************/ + +// REQUIRES: c++20 +// RUN: cudaq-quake %s -verify + +#include "cudaq.h" + +struct s { + cudaq::qview<> s; +}; +// expected-error@+2{{recursive quantum struct types are not allowed}} +// expected-error@+1{{quantum struct has invalid member type}} +struct test { + cudaq::qview<> q; + cudaq::qview<> r; + s s; +}; +__qpu__ void entry_ctor() { + cudaq::qvector q(2), r(2); + s s(q); + test tt(q, r, s); + h(tt.r[0]); +} diff --git a/test/AST-error/struct_quantum_and_classical.cpp b/test/AST-error/struct_quantum_and_classical.cpp new file mode 100644 index 0000000000..cdbdaf1517 --- /dev/null +++ b/test/AST-error/struct_quantum_and_classical.cpp @@ -0,0 +1,34 @@ +/******************************************************************************* + * Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. * + ******************************************************************************/ + +// REQUIRES: c++20 +// RUN: cudaq-quake %s -verify + +#include "cudaq.h" + +// expected-error@+1 {{hybrid quantum-classical struct types are not allowed}} +struct test { + int i; + double d; + cudaq::qview<> q; +}; + +__qpu__ void hello(cudaq::qubit &q) { h(q); } + +__qpu__ void kernel(test t) { + h(t.q); + hello(t.q[0]); +} + +__qpu__ void entry(int i) { + cudaq::qvector q(i); + test tt{1, 2.2, q}; + // this fails non-default ctor ConvertExpr:2899, + // but this is not what we are testing here + // kernel(tt); +} diff --git a/test/Quake/invalid.qke b/test/Quake/invalid.qke new file mode 100644 index 0000000000..1dd79e84c0 --- /dev/null +++ b/test/Quake/invalid.qke @@ -0,0 +1,56 @@ +// ========================================================================== // +// Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. // +// All rights reserved. // +// // +// This source code and the accompanying materials are made available under // +// the terms of the Apache License 2.0 which accompanies this distribution. // +// ========================================================================== // + +// RUN: cudaq-opt %s -split-input-file -verify-diagnostics + +func.func @test_struq() { + %0 = quake.alloca !quake.veq<4> + %1 = arith.constant 1 : i32 + %2 = arith.constant 2.0 : f32 + // expected-error@+1 {{must be non-struct quantum reference type}} + %6 = quake.make_struq %0, %1, %2 : (!quake.veq<4>, i32, f32) -> !quake.struq, i32, f32> + return +} + +// ----- + +func.func @test_struq() { + %0 = quake.alloca !quake.veq<4> + %1 = quake.alloca !quake.veq<7> + // expected-error@+1 {{member type not compatible with operand type}} + %6 = quake.make_struq %0, %1 : (!quake.veq<4>, !quake.veq<7>) -> !quake.struq, !quake.veq<8>> + return +} + +// ----- + +func.func @test_struq() { + %0 = quake.alloca !quake.veq<4> + %1 = quake.alloca !quake.veq<7> + // expected-error@+1 {{result type has different member count than operands}} + %6 = quake.make_struq %0, %1 : (!quake.veq<4>, !quake.veq<7>) -> !quake.struq> + return +} + +// ----- + +func.func @test_struq() { + %0 = quake.alloca !quake.veq<4> + %1 = quake.alloca !quake.veq<7> + // expected-error@+1 {{result type has different member count than operands}} + %6 = quake.make_struq %0, %1 : (!quake.veq<4>, !quake.veq<7>) -> !quake.struq, !quake.veq, !quake.veq> + return +} + +// ----- + +func.func @test_struq(%arg : !quake.struq, !quake.veq<2>, !quake.veq<3>>) { + // expected-error@+1 {{invalid index}} + %6 = quake.get_member %arg[3] : (!quake.struq, !quake.veq<2>, !quake.veq<3>>) -> !quake.veq<1> + return +} diff --git a/test/Quake/roundtrip-ops.qke b/test/Quake/roundtrip-ops.qke index bca0639fec..2d094e4060 100644 --- a/test/Quake/roundtrip-ops.qke +++ b/test/Quake/roundtrip-ops.qke @@ -801,3 +801,41 @@ func.func @indirect_callable2(%arg : !cc.indirect_callable<(i32) -> i64>) -> i64 // CHECK: %[[VAL_2:.*]] = cc.call_indirect_callable %[[VAL_0]], %[[VAL_1]] : (!cc.indirect_callable<(i32) -> i64>, i32) -> i64 // CHECK: return %[[VAL_2]] : i64 // CHECK: } + +func.func @quantum_product_type() { + %0 = quake.alloca !quake.veq<3> + %1 = quake.alloca !quake.veq<4> + %2 = quake.make_struq %0, %1 : (!quake.veq<3>, !quake.veq<4>) -> !quake.struq, !quake.veq> + %3 = quake.get_member %2[0] : (!quake.struq, !quake.veq>) -> !quake.veq + %4 = quake.get_member %2[1] : (!quake.struq, !quake.veq>) -> !quake.veq + %10 = quake.alloca !quake.veq<5> + %11 = quake.alloca !quake.veq<6> + %12 = quake.make_struq %10, %11 : (!quake.veq<5>, !quake.veq<6>) -> !quake.struq<"gumby": !quake.veq, !quake.veq> + %13 = quake.get_member %12[0] : (!quake.struq<"gumby": !quake.veq, !quake.veq>) -> !quake.veq + %14 = quake.get_member %12[1] : (!quake.struq<"gumby": !quake.veq, !quake.veq>) -> !quake.veq + %20 = quake.alloca !quake.veq<7> + %21 = quake.alloca !quake.veq<8> + %22 = quake.make_struq %20, %21 : (!quake.veq<7>, !quake.veq<8>) -> !quake.struq, !quake.veq<8>> + %23 = quake.get_member %22[0] : (!quake.struq, !quake.veq<8>>) -> !quake.veq<7> + %24 = quake.get_member %22[1] : (!quake.struq, !quake.veq<8>>) -> !quake.veq<8> + return + } + +// CHECK-LABEL: func.func @quantum_product_type() { +// CHECK: %[[VAL_0:.*]] = quake.alloca !quake.veq<3> +// CHECK: %[[VAL_1:.*]] = quake.alloca !quake.veq<4> +// CHECK: %[[VAL_2:.*]] = quake.make_struq %[[VAL_0]], %[[VAL_1]] : (!quake.veq<3>, !quake.veq<4>) -> !quake.struq, !quake.veq> +// CHECK: %[[VAL_3:.*]] = quake.get_member %[[VAL_2]][0] : (!quake.struq, !quake.veq>) -> !quake.veq +// CHECK: %[[VAL_4:.*]] = quake.get_member %[[VAL_2]][1] : (!quake.struq, !quake.veq>) -> !quake.veq +// CHECK: %[[VAL_5:.*]] = quake.alloca !quake.veq<5> +// CHECK: %[[VAL_6:.*]] = quake.alloca !quake.veq<6> +// CHECK: %[[VAL_7:.*]] = quake.make_struq %[[VAL_5]], %[[VAL_6]] : (!quake.veq<5>, !quake.veq<6>) -> !quake.struq<"gumby": !quake.veq, !quake.veq> +// CHECK: %[[VAL_8:.*]] = quake.get_member %[[VAL_7]][0] : (!quake.struq<"gumby": !quake.veq, !quake.veq>) -> !quake.veq +// CHECK: %[[VAL_9:.*]] = quake.get_member %[[VAL_7]][1] : (!quake.struq<"gumby": !quake.veq, !quake.veq>) -> !quake.veq +// CHECK: %[[VAL_10:.*]] = quake.alloca !quake.veq<7> +// CHECK: %[[VAL_11:.*]] = quake.alloca !quake.veq<8> +// CHECK: %[[VAL_12:.*]] = quake.make_struq %[[VAL_10]], %[[VAL_11]] : (!quake.veq<7>, !quake.veq<8>) -> !quake.struq, !quake.veq<8>> +// CHECK: %[[VAL_13:.*]] = quake.get_member %[[VAL_12]][0] : (!quake.struq, !quake.veq<8>>) -> !quake.veq<7> +// CHECK: %[[VAL_14:.*]] = quake.get_member %[[VAL_12]][1] : (!quake.struq, !quake.veq<8>>) -> !quake.veq<8> +// CHECK: return +// CHECK: }