Skip to content

Commit

Permalink
[kernel_builder] Use quake.reset with a veq argument.
Browse files Browse the repository at this point in the history
Simplify the kernel_builder to use both qubit and veq forms of the reset
op. Add a pattern to expand measurements that also expands resets.

Requires #1910 to be merged first.
  • Loading branch information
schweitzpgi committed Jul 12, 2024
1 parent 041cc76 commit 02ac2f2
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 27 deletions.
3 changes: 3 additions & 0 deletions include/cudaq/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,9 @@ def ExpandMeasurements : Pass<"expand-measurements"> {
The target may only support measuring a single qubit however. This pass
expands these ops in list format into a series of measurements (including
loops) on individual qubits and into a single `std::vector<bool>` result.

The `reset` op can also take a veq argument and this pass will also expand
that to a series of `reset` operations on single qubits.
}];

let dependentDialects = ["cudaq::cc::CCDialect", "mlir::LLVM::LLVMDialect"];
Expand Down
3 changes: 1 addition & 2 deletions lib/Optimizer/CodeGen/QuakeToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -385,11 +385,10 @@ class ResetRewrite : public ConvertOpToLLVMPattern<quake::ResetOp> {
ConversionPatternRewriter &rewriter) const override {
auto parentModule = instOp->getParentOfType<ModuleOp>();
auto context = parentModule->getContext();
std::string qirQisPrefix{cudaq::opt::QIRQISPrefix};
std::string instName = instOp->getName().stripDialect().str();

// Get the reset QIR function name
auto qirFunctionName = qirQisPrefix + instName;
auto qirFunctionName = cudaq::opt::QIRQISPrefix + instName;

// Create the qubit pointer type
auto qirQubitPointerType = cudaq::opt::getQubitType(context);
Expand Down
29 changes: 28 additions & 1 deletion lib/Optimizer/Transforms/ExpandMeasurements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,38 @@ using MxRewrite = ExpandRewritePattern<quake::MxOp>;
using MyRewrite = ExpandRewritePattern<quake::MyOp>;
using MzRewrite = ExpandRewritePattern<quake::MzOp>;

/// Convert a `quake.reset` with a `veq` argument into a loop over the elements
/// of the `veq` and `quake.reset` on each of them.
class ResetRewrite : public OpRewritePattern<quake::ResetOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(quake::ResetOp resetOp,
PatternRewriter &rewriter) const override {
auto loc = resetOp.getLoc();
auto veqArg = resetOp.getTargets();
auto i64Ty = rewriter.getI64Type();
Value vecSz = rewriter.create<quake::VeqSizeOp>(loc, i64Ty, veqArg);
cudaq::opt::factory::createInvariantLoop(
rewriter, loc, vecSz,
[&](OpBuilder &builder, Location loc, Region &, Block &block) {
Value iv = block.getArgument(0);
Value qv = builder.create<quake::ExtractRefOp>(loc, veqArg, iv);
builder.create<quake::ResetOp>(loc, TypeRange{}, qv);
});
rewriter.eraseOp(resetOp);
return success();
}
};

class ExpandMeasurementsPass
: public cudaq::opt::ExpandMeasurementsBase<ExpandMeasurementsPass> {
public:
void runOnOperation() override {
auto *op = getOperation();
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
patterns.insert<MxRewrite, MyRewrite, MzRewrite>(ctx);
patterns.insert<MxRewrite, MyRewrite, MzRewrite, ResetRewrite>(ctx);
ConversionTarget target(*ctx);
target.addLegalDialect<quake::QuakeDialect, cudaq::cc::CCDialect,
arith::ArithDialect, LLVM::LLVMDialect>();
Expand All @@ -139,6 +163,9 @@ class ExpandMeasurementsPass
[](quake::MyOp x) { return usesIndividualQubit(x.getMeasOut()); });
target.addDynamicallyLegalOp<quake::MzOp>(
[](quake::MzOp x) { return usesIndividualQubit(x.getMeasOut()); });
target.addDynamicallyLegalOp<quake::ResetOp>([](quake::ResetOp r) {
return !isa<quake::VeqType>(r.getTargets().getType());
});
if (failed(applyPartialConversion(op, target, std::move(patterns)))) {
op->emitOpError("could not expand measurements");
signalPassFailure();
Expand Down
25 changes: 1 addition & 24 deletions runtime/cudaq/builder/kernel_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -792,30 +792,7 @@ QuakeValue mz(ImplicitLocOpBuilder &builder, QuakeValue &qubitOrQvec,
}

void reset(ImplicitLocOpBuilder &builder, const QuakeValue &qubitOrQvec) {
auto value = qubitOrQvec.getValue();
if (isa<quake::RefType>(value.getType())) {
builder.create<quake::ResetOp>(TypeRange{}, value);
return;
}

if (isa<quake::VeqType>(value.getType())) {
auto target = value;
Value rank = builder.create<quake::VeqSizeOp>(builder.getI64Type(), target);
auto bodyBuilder = [&](OpBuilder &builder, Location loc, Region &,
Block &block) {
Value ref = builder.create<quake::ExtractRefOp>(loc, target,
block.getArgument(0));
builder.create<quake::ResetOp>(loc, TypeRange{}, ref);
};
cudaq::opt::factory::createInvariantLoop(builder, builder.getUnknownLoc(),
rank, bodyBuilder);
return;
}

llvm::errs() << "Invalid type:\n";
value.getType().dump();
llvm::errs() << '\n';
throw std::runtime_error("Invalid type passed to reset().");
builder.create<quake::ResetOp>(TypeRange{}, qubitOrQvec.getValue());
}

void swap(ImplicitLocOpBuilder &builder, const std::vector<QuakeValue> &ctrls,
Expand Down

0 comments on commit 02ac2f2

Please sign in to comment.