From 02ac2f2bdccfc205998e95fb5c075a14f89d6aae Mon Sep 17 00:00:00 2001 From: Eric Schweitz Date: Wed, 10 Jul 2024 10:44:18 -0700 Subject: [PATCH] [kernel_builder] Use quake.reset with a veq argument. Simplify the kernel_builder to use both qubit and veq forms of the reset op. Add a pattern to expand measurements that also expands resets. Requires #1910 to be merged first. --- include/cudaq/Optimizer/Transforms/Passes.td | 3 ++ lib/Optimizer/CodeGen/QuakeToLLVM.cpp | 3 +- .../Transforms/ExpandMeasurements.cpp | 29 ++++++++++++++++++- runtime/cudaq/builder/kernel_builder.cpp | 25 +--------------- 4 files changed, 33 insertions(+), 27 deletions(-) diff --git a/include/cudaq/Optimizer/Transforms/Passes.td b/include/cudaq/Optimizer/Transforms/Passes.td index 9deb13b9360..046c124edb6 100644 --- a/include/cudaq/Optimizer/Transforms/Passes.td +++ b/include/cudaq/Optimizer/Transforms/Passes.td @@ -197,6 +197,9 @@ def ExpandMeasurements : Pass<"expand-measurements"> { The target may only support measuring a single qubit however. This pass expands these ops in list format into a series of measurements (including loops) on individual qubits and into a single `std::vector` result. + + The `reset` op can also take a veq argument and this pass will also expand + that to a series of `reset` operations on single qubits. }]; let dependentDialects = ["cudaq::cc::CCDialect", "mlir::LLVM::LLVMDialect"]; diff --git a/lib/Optimizer/CodeGen/QuakeToLLVM.cpp b/lib/Optimizer/CodeGen/QuakeToLLVM.cpp index f1828999ed8..78ce099c515 100644 --- a/lib/Optimizer/CodeGen/QuakeToLLVM.cpp +++ b/lib/Optimizer/CodeGen/QuakeToLLVM.cpp @@ -385,11 +385,10 @@ class ResetRewrite : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override { auto parentModule = instOp->getParentOfType(); auto context = parentModule->getContext(); - std::string qirQisPrefix{cudaq::opt::QIRQISPrefix}; std::string instName = instOp->getName().stripDialect().str(); // Get the reset QIR function name - auto qirFunctionName = qirQisPrefix + instName; + auto qirFunctionName = cudaq::opt::QIRQISPrefix + instName; // Create the qubit pointer type auto qirQubitPointerType = cudaq::opt::getQubitType(context); diff --git a/lib/Optimizer/Transforms/ExpandMeasurements.cpp b/lib/Optimizer/Transforms/ExpandMeasurements.cpp index 2c369a6cdbb..b2e690ae393 100644 --- a/lib/Optimizer/Transforms/ExpandMeasurements.cpp +++ b/lib/Optimizer/Transforms/ExpandMeasurements.cpp @@ -122,6 +122,30 @@ using MxRewrite = ExpandRewritePattern; using MyRewrite = ExpandRewritePattern; using MzRewrite = ExpandRewritePattern; +/// Convert a `quake.reset` with a `veq` argument into a loop over the elements +/// of the `veq` and `quake.reset` on each of them. +class ResetRewrite : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(quake::ResetOp resetOp, + PatternRewriter &rewriter) const override { + auto loc = resetOp.getLoc(); + auto veqArg = resetOp.getTargets(); + auto i64Ty = rewriter.getI64Type(); + Value vecSz = rewriter.create(loc, i64Ty, veqArg); + cudaq::opt::factory::createInvariantLoop( + rewriter, loc, vecSz, + [&](OpBuilder &builder, Location loc, Region &, Block &block) { + Value iv = block.getArgument(0); + Value qv = builder.create(loc, veqArg, iv); + builder.create(loc, TypeRange{}, qv); + }); + rewriter.eraseOp(resetOp); + return success(); + } +}; + class ExpandMeasurementsPass : public cudaq::opt::ExpandMeasurementsBase { public: @@ -129,7 +153,7 @@ class ExpandMeasurementsPass auto *op = getOperation(); auto *ctx = &getContext(); RewritePatternSet patterns(ctx); - patterns.insert(ctx); + patterns.insert(ctx); ConversionTarget target(*ctx); target.addLegalDialect(); @@ -139,6 +163,9 @@ class ExpandMeasurementsPass [](quake::MyOp x) { return usesIndividualQubit(x.getMeasOut()); }); target.addDynamicallyLegalOp( [](quake::MzOp x) { return usesIndividualQubit(x.getMeasOut()); }); + target.addDynamicallyLegalOp([](quake::ResetOp r) { + return !isa(r.getTargets().getType()); + }); if (failed(applyPartialConversion(op, target, std::move(patterns)))) { op->emitOpError("could not expand measurements"); signalPassFailure(); diff --git a/runtime/cudaq/builder/kernel_builder.cpp b/runtime/cudaq/builder/kernel_builder.cpp index c8149574d61..a37d2403ed2 100644 --- a/runtime/cudaq/builder/kernel_builder.cpp +++ b/runtime/cudaq/builder/kernel_builder.cpp @@ -792,30 +792,7 @@ QuakeValue mz(ImplicitLocOpBuilder &builder, QuakeValue &qubitOrQvec, } void reset(ImplicitLocOpBuilder &builder, const QuakeValue &qubitOrQvec) { - auto value = qubitOrQvec.getValue(); - if (isa(value.getType())) { - builder.create(TypeRange{}, value); - return; - } - - if (isa(value.getType())) { - auto target = value; - Value rank = builder.create(builder.getI64Type(), target); - auto bodyBuilder = [&](OpBuilder &builder, Location loc, Region &, - Block &block) { - Value ref = builder.create(loc, target, - block.getArgument(0)); - builder.create(loc, TypeRange{}, ref); - }; - cudaq::opt::factory::createInvariantLoop(builder, builder.getUnknownLoc(), - rank, bodyBuilder); - return; - } - - llvm::errs() << "Invalid type:\n"; - value.getType().dump(); - llvm::errs() << '\n'; - throw std::runtime_error("Invalid type passed to reset()."); + builder.create(TypeRange{}, qubitOrQvec.getValue()); } void swap(ImplicitLocOpBuilder &builder, const std::vector &ctrls,