Skip to content

Commit

Permalink
State pointer synthesis for quantum hardware
Browse files Browse the repository at this point in the history
Signed-off-by: Anna Gringauze <[email protected]>
  • Loading branch information
annagrin committed Oct 17, 2024
1 parent 9327b05 commit 6a71755
Show file tree
Hide file tree
Showing 31 changed files with 955 additions and 51 deletions.
4 changes: 4 additions & 0 deletions include/cudaq/Optimizer/Builder/Intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ static constexpr const char createCudaqStateFromDataFP32[] =
// Delete a state created by the runtime functions above.
static constexpr const char deleteCudaqState[] = "__nvqpp_cudaq_state_delete";

// Get state of a kernel (placeholder function, calls are always replaced in
// opts)
static constexpr const char getCudaqState[] = "__nvqpp_cudaq_state_get";

/// Builder for lowering the clang AST to an IR for CUDA-Q. Lowering includes
/// the transformation of both quantum and classical computation. Different
/// features of the CUDA-Q programming model are lowered into different dialects
Expand Down
38 changes: 38 additions & 0 deletions include/cudaq/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,44 @@ def DeleteStates : Pass<"delete-states", "mlir::ModuleOp"> {
}];
}

def StateInitialization : Pass<"state-initialization", "mlir::ModuleOp"> {
let summary =
"Replace `quake.init_state` instructions with call to the kernel generating the state";
let description = [{
Argument synthesis for state pointers for quantum devices substitutes state
argument by a new state created from `__nvqpp_cudaq_state_get` intrinsic, which
in turn accepts the name for the synthesized kernel that generated the state.

This optimization completes the replacement of `quake.init_state` instruction by:

- Replace `quake.init_state` by a call that `get_state` call refers to.
- Remove all unneeded instructions.

For example:

Before StateInitialization (state-initialization):
```
func.func @foo() attributes {"cudaq-entrypoint", "cudaq-kernel", no_this} {
%0 = cc.string_literal "__nvqpp__mlirgen__test_init_state.modified_0" : !cc.ptr<!cc.array<i8 x 45>>
%1 = cc.cast %0 : (!cc.ptr<!cc.array<i8 x 45>>) -> !cc.ptr<i8>
%2 = call @__nvqpp_cudaq_state_get(%1) : (!cc.ptr<i8>) -> !cc.ptr<!cc.state>
%3 = call @__nvqpp_cudaq_state_numberOfQubits(%2) : (!cc.ptr<!cc.state>) -> i64
%4 = quake.alloca !quake.veq<?>[%3 : i64]
%5 = quake.init_state %4, %2 : (!quake.veq<?>, !cc.ptr<!cc.state>) -> !quake.veq<?>
return
}
```

After StateInitialization (state-initialization):
```
func.func @foo() attributes {"cudaq-entrypoint", "cudaq-kernel", no_this} {
%5 = call @__nvqpp__mlirgen__test_init_state.modified_0() : () -> !quake.veq<?>
return
}
```
}];
}

def StatePreparation : Pass<"state-prep", "mlir::ModuleOp"> {
let summary =
"Convert state vector data into gates";
Expand Down
4 changes: 4 additions & 0 deletions lib/Optimizer/Builder/Intrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,10 @@ static constexpr IntrinsicCode intrinsicTable[] = {

{cudaq::deleteCudaqState, {}, R"#(
func.func private @__nvqpp_cudaq_state_delete(%p : !cc.ptr<!cc.state>) -> ()
)#"},

{cudaq::getCudaqState, {}, R"#(
func.func private @__nvqpp_cudaq_state_get(%p : !cc.ptr<i8>) -> !cc.ptr<!cc.state>
)#"},

{cudaq::getNumQubitsFromCudaqState, {}, R"#(
Expand Down
3 changes: 2 additions & 1 deletion lib/Optimizer/CodeGen/VerifyNVQIRCalls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ struct VerifyNVQIRCallOpsPass
cudaq::getNumQubitsFromCudaqState,
cudaq::createCudaqStateFromDataFP32,
cudaq::createCudaqStateFromDataFP64,
cudaq::deleteCudaqState};
cudaq::deleteCudaqState,
cudaq::getCudaqState};
// It must be either NVQIR extension functions or in the allowed list.
return std::find(NVQIR_FUNCS.begin(), NVQIR_FUNCS.end(), functionName) !=
NVQIR_FUNCS.end() ||
Expand Down
1 change: 1 addition & 0 deletions lib/Optimizer/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ add_cudaq_library(OptTransforms
QuakeSynthesizer.cpp
RefToVeqAlloc.cpp
RegToMem.cpp
StateInitialization.cpp
StatePreparation.cpp
UnitarySynthesis.cpp
WiresToWiresets.cpp
Expand Down
11 changes: 7 additions & 4 deletions lib/Optimizer/Transforms/LiftArrayAlloc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,10 @@ class AllocaPattern : public OpRewritePattern<cudaq::cc::AllocaOp> {
if (auto load = dyn_cast<cudaq::cc::LoadOp>(useuser)) {
rewriter.setInsertionPointAfter(useuser);
LLVM_DEBUG(llvm::dbgs() << "replaced load\n");
rewriter.replaceOpWithNewOp<cudaq::cc::ExtractValueOp>(
load, eleTy, conArr,
ArrayRef<cudaq::cc::ExtractValueArg>{offset});
auto extract = rewriter.create<cudaq::cc::ExtractValueOp>(
loc, eleTy, conArr, ArrayRef<cudaq::cc::ExtractValueArg>{offset});
rewriter.replaceAllUsesWith(load, extract);
toErase.push_back(load);
continue;
}
if (isa<cudaq::cc::StoreOp>(useuser))
Expand All @@ -199,8 +200,10 @@ class AllocaPattern : public OpRewritePattern<cudaq::cc::AllocaOp> {
toErase.push_back(alloc);
}

for (auto *op : toErase)
for (auto *op : toErase) {
op->dropAllUses();
rewriter.eraseOp(op);
}

return success();
}
Expand Down
146 changes: 146 additions & 0 deletions lib/Optimizer/Transforms/StateInitialization.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*******************************************************************************
* Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. *
* All rights reserved. *
* *
* This source code and the accompanying materials are made available under *
* the terms of the Apache License 2.0 which accompanies this distribution. *
******************************************************************************/

#include "PassDetails.h"
#include "cudaq/Optimizer/Builder/Intrinsics.h"
#include "cudaq/Optimizer/Dialect/CC/CCOps.h"
#include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h"
#include "cudaq/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include <span>

namespace cudaq::opt {
#define GEN_PASS_DEF_STATEINITIALIZATION
#include "cudaq/Optimizer/Transforms/Passes.h.inc"
} // namespace cudaq::opt

#define DEBUG_TYPE "state-initialization"

using namespace mlir;

namespace {

static bool isCall(Operation *callOp, std::vector<const char *> &&names) {
if (callOp) {
if (auto createStateCall = dyn_cast<func::CallOp>(callOp)) {
if (auto calleeAttr = createStateCall.getCalleeAttr()) {
auto funcName = calleeAttr.getValue().str();
if (std::find(names.begin(), names.end(), funcName) != names.end())
return true;
}
}
}
return false;
}

static bool isGetStateCall(Operation *callOp) {
return isCall(callOp, {cudaq::getCudaqState});
}

static bool isNumberOfQubitsCall(Operation *callOp) {
return isCall(callOp, {cudaq::getNumQubitsFromCudaqState});
}

// clang-format off
/// Replace `quake.init_state` by a call to a (modified) kernel that produced the state.
/// ```
/// %0 = cc.string_literal "callee.modified_0" : !cc.ptr<!cc.array<i8 x 27>>
/// %1 = cc.cast %0 : (!cc.ptr<!cc.array<i8 x 27>>) -> !cc.ptr<i8>
/// %2 = call @__nvqpp_cudaq_state_get(%1) : (!cc.ptr<i8>) -> !cc.ptr<!cc.state>
/// %3 = call @__nvqpp_cudaq_state_numberOfQubits(%2) : (!cc.ptr<!cc.state>) -> i64
/// %4 = quake.alloca !quake.veq<?>[%3 : i64]
/// %5 = quake.init_state %4, %2 : (!quake.veq<?>, !cc.ptr<!cc.state>) -> !quake.veq<?>
/// ───────────────────────────────────────────
/// ...
/// %5 = call @callee.modified_0() : () -> !quake.veq<?>
/// ```
// clang-format on
class StateInitPattern : public OpRewritePattern<quake::InitializeStateOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(quake::InitializeStateOp initState,
PatternRewriter &rewriter) const override {
auto loc = initState.getLoc();
auto allocaOp = initState.getOperand(0).getDefiningOp();
auto getStateOp = initState.getOperand(1).getDefiningOp();
auto numOfQubits = allocaOp->getOperand(0).getDefiningOp();

if (isGetStateCall(getStateOp)) {
auto calleeNameOp = getStateOp->getOperand(0);
if (auto cast =
dyn_cast<cudaq::cc::CastOp>(calleeNameOp.getDefiningOp())) {
calleeNameOp = cast.getOperand();

if (auto literal = dyn_cast<cudaq::cc::CreateStringLiteralOp>(
calleeNameOp.getDefiningOp())) {
auto calleeName = literal.getStringLiteral();

Value result =
rewriter
.create<func::CallOp>(loc, initState.getType(), calleeName,
mlir::ValueRange{})
.getResult(0);
rewriter.replaceAllUsesWith(initState, result);
initState.erase();
allocaOp->dropAllUses();
rewriter.eraseOp(allocaOp);
if (isNumberOfQubitsCall(numOfQubits)) {
numOfQubits->dropAllUses();
rewriter.eraseOp(numOfQubits);
}
getStateOp->dropAllUses();
rewriter.eraseOp(getStateOp);
cast->dropAllUses();
rewriter.eraseOp(cast);
literal->dropAllUses();
rewriter.eraseOp(literal);
return success();
}
}
}
return failure();
}
};

class StateInitializationPass
: public cudaq::opt::impl::StateInitializationBase<
StateInitializationPass> {
public:
using StateInitializationBase::StateInitializationBase;

void runOnOperation() override {
auto *ctx = &getContext();
auto module = getOperation();
for (Operation &op : *module.getBody()) {
auto func = dyn_cast<func::FuncOp>(op);
if (!func)
continue;

std::string funcName = func.getName().str();
RewritePatternSet patterns(ctx);
patterns.insert<StateInitPattern>(ctx);

LLVM_DEBUG(llvm::dbgs()
<< "Before state initialization: " << func << '\n');

if (failed(applyPatternsAndFoldGreedily(func.getOperation(),
std::move(patterns))))
signalPassFailure();

LLVM_DEBUG(llvm::dbgs()
<< "After state initialization: " << func << '\n');
}
}
};
} // namespace
5 changes: 3 additions & 2 deletions python/runtime/cudaq/algorithms/py_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,9 @@ class PyRemoteSimulationState : public RemoteSimulationState {
}
}

std::pair<std::string, std::vector<void *>> getKernelInfo() const override {
return {kernelName, argsData->getArgs()};
std::optional<std::pair<std::string, std::vector<void *>>>
getKernelInfo() const override {
return std::make_pair(kernelName, argsData->getArgs());
}

std::complex<double> overlap(const cudaq::SimulationState &other) override {
Expand Down
2 changes: 1 addition & 1 deletion python/runtime/cudaq/platform/py_alt_launch_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ MlirModule synthesizeKernel(const std::string &name, MlirModule module,
auto isLocalSimulator = platform.is_simulator() && !platform.is_emulated();
auto isSimulator = isLocalSimulator || isRemoteSimulator;

cudaq::opt::ArgumentConverter argCon(name, unwrap(module), isSimulator);
cudaq::opt::ArgumentConverter argCon(name, unwrap(module));
argCon.gen(runtimeArgs.getArgs());
std::string kernName = cudaq::runtime::cudaqGenPrefixName + name;
SmallVector<StringRef> kernels = {kernName};
Expand Down
Loading

0 comments on commit 6a71755

Please sign in to comment.