Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

State synthesis for quantum devices #2291

Draft
wants to merge 25 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
ac01dd1
DCO Remediation Commit for Ben Howe <[email protected]>
bmhowe23 Oct 11, 2024
21a87c1
State pointer synthesis for quantum hardware
annagrin Sep 17, 2024
3fc56de
Merge with main
annagrin Oct 17, 2024
7969a75
Merge with main
annagrin Oct 17, 2024
755d0d1
Fix test failure on anyon platform
annagrin Oct 17, 2024
dc5e77e
Merge branch 'main' of https://github.com/NVIDIA/cuda-quantum into qu…
annagrin Oct 17, 2024
382bc99
Make StateInitialization a funcOp pass
annagrin Oct 17, 2024
d3a05d4
Fix issues and tests for the rest of quantum architectures
annagrin Oct 18, 2024
ac151f2
Merge with main
annagrin Oct 18, 2024
51ef054
Fix failing quantinuum state prep tests
annagrin Oct 18, 2024
0cdf3e9
Merge branch 'main' of https://github.com/NVIDIA/cuda-quantum into qu…
annagrin Oct 18, 2024
5307aa4
Merge branch 'main' of https://github.com/NVIDIA/cuda-quantum into qu…
annagrin Oct 21, 2024
a7f5387
Address CR comments
annagrin Oct 21, 2024
eb8db13
Merge with main
annagrin Oct 21, 2024
9f0937f
Format
annagrin Oct 21, 2024
2f3a623
Fix failing test
annagrin Oct 22, 2024
b381350
Format
annagrin Oct 22, 2024
dc87ca4
Format
annagrin Oct 22, 2024
e4c7735
Merge branch 'main' of https://github.com/NVIDIA/cuda-quantum into qu…
annagrin Oct 22, 2024
53a34c9
Replaced getState intrinsic by cc.get_state op
annagrin Oct 22, 2024
30777f3
Merge branch 'main' of https://github.com/NVIDIA/cuda-quantum into qu…
annagrin Oct 22, 2024
fe6d409
Remove print
annagrin Oct 22, 2024
48704e3
Remove getCudaqState references
annagrin Oct 22, 2024
137f621
Minor updates
annagrin Oct 22, 2024
ad7c6bc
Fix failing quake test
annagrin Oct 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/cudaq/Optimizer/Builder/Intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ static constexpr const char createCudaqStateFromDataFP32[] =
// Delete a state created by the runtime functions above.
static constexpr const char deleteCudaqState[] = "__nvqpp_cudaq_state_delete";

// Get state of a kernel (placeholder function, calls are always replaced in
// opts)
static constexpr const char getCudaqState[] = "__nvqpp_cudaq_state_get";
annagrin marked this conversation as resolved.
Show resolved Hide resolved

/// Builder for lowering the clang AST to an IR for CUDA-Q. Lowering includes
/// the transformation of both quantum and classical computation. Different
/// features of the CUDA-Q programming model are lowered into different dialects
Expand Down
49 changes: 49 additions & 0 deletions include/cudaq/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,44 @@ def DeleteStates : Pass<"delete-states", "mlir::ModuleOp"> {
}];
}

def StateInitialization : Pass<"state-initialization", "mlir::func::FuncOp"> {
annagrin marked this conversation as resolved.
Show resolved Hide resolved
let summary =
"Replace `quake.init_state` instructions with call to the kernel generating the state";
let description = [{
Argument synthesis for state pointers for quantum devices substitutes state
argument by a new state created from `__nvqpp_cudaq_state_get` intrinsic, which
in turn accepts the name for the synthesized kernel that generated the state.

This optimization completes the replacement of `quake.init_state` instruction by:

- Replace `quake.init_state` by a call that `get_state` call refers to.
- Remove all unneeded instructions.

For example:

Before StateInitialization (state-initialization):
```
func.func @foo() attributes {"cudaq-entrypoint", "cudaq-kernel", no_this} {
%0 = cc.string_literal "__nvqpp__mlirgen__test_init_state.modified_0" : !cc.ptr<!cc.array<i8 x 45>>
%1 = cc.cast %0 : (!cc.ptr<!cc.array<i8 x 45>>) -> !cc.ptr<i8>
%2 = call @__nvqpp_cudaq_state_get(%1) : (!cc.ptr<i8>) -> !cc.ptr<!cc.state>
%3 = call @__nvqpp_cudaq_state_numberOfQubits(%2) : (!cc.ptr<!cc.state>) -> i64
%4 = quake.alloca !quake.veq<?>[%3 : i64]
%5 = quake.init_state %4, %2 : (!quake.veq<?>, !cc.ptr<!cc.state>) -> !quake.veq<?>
return
}
```

After StateInitialization (state-initialization):
```
func.func @foo() attributes {"cudaq-entrypoint", "cudaq-kernel", no_this} {
%5 = call @__nvqpp__mlirgen__test_init_state.modified_0() : () -> !quake.veq<?>
return
}
```
}];
}

def StatePreparation : Pass<"state-prep", "mlir::ModuleOp"> {
let summary =
"Convert state vector data into gates";
Expand Down Expand Up @@ -828,6 +866,17 @@ def StatePreparation : Pass<"state-prep", "mlir::ModuleOp"> {
];
}

def StateValidation : Pass<"state-validation", "mlir::ModuleOp"> {
annagrin marked this conversation as resolved.
Show resolved Hide resolved
let summary =
"Make sure MLIR is valid after synthesis for quantum devices";
let description = [{
Argument synthesis should replace all `quake.init` from state instructions
annagrin marked this conversation as resolved.
Show resolved Hide resolved
and calls to state-related runtime functions.
Make sure none of them left, and remove definitions for state-related
annagrin marked this conversation as resolved.
Show resolved Hide resolved
runtime functions.
}];
}

def PromoteRefToVeqAlloc : Pass<"promote-qubit-allocation"> {
let summary = "Promote single qubit allocations.";
let description = [{
Expand Down
4 changes: 4 additions & 0 deletions lib/Optimizer/Builder/Intrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,10 @@ static constexpr IntrinsicCode intrinsicTable[] = {

{cudaq::deleteCudaqState, {}, R"#(
func.func private @__nvqpp_cudaq_state_delete(%p : !cc.ptr<!cc.state>) -> ()
)#"},

{cudaq::getCudaqState, {}, R"#(
func.func private @__nvqpp_cudaq_state_get(%p : !cc.ptr<i8>) -> !cc.ptr<!cc.state>
)#"},

{cudaq::getNumQubitsFromCudaqState, {}, R"#(
Expand Down
3 changes: 2 additions & 1 deletion lib/Optimizer/CodeGen/VerifyNVQIRCalls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ struct VerifyNVQIRCallOpsPass
cudaq::getNumQubitsFromCudaqState,
cudaq::createCudaqStateFromDataFP32,
cudaq::createCudaqStateFromDataFP64,
cudaq::deleteCudaqState};
cudaq::deleteCudaqState,
cudaq::getCudaqState};
// It must be either NVQIR extension functions or in the allowed list.
return std::find(NVQIR_FUNCS.begin(), NVQIR_FUNCS.end(), functionName) !=
NVQIR_FUNCS.end() ||
Expand Down
2 changes: 2 additions & 0 deletions lib/Optimizer/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ add_cudaq_library(OptTransforms
QuakeSynthesizer.cpp
RefToVeqAlloc.cpp
RegToMem.cpp
StateInitialization.cpp
StatePreparation.cpp
StateValidation.cpp
UnitarySynthesis.cpp
WiresToWiresets.cpp

Expand Down
141 changes: 141 additions & 0 deletions lib/Optimizer/Transforms/StateInitialization.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*******************************************************************************
* Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. *
* All rights reserved. *
* *
* This source code and the accompanying materials are made available under *
* the terms of the Apache License 2.0 which accompanies this distribution. *
******************************************************************************/

#include "PassDetails.h"
#include "cudaq/Optimizer/Builder/Intrinsics.h"
#include "cudaq/Optimizer/Dialect/CC/CCOps.h"
#include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h"
#include "cudaq/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include <span>

namespace cudaq::opt {
#define GEN_PASS_DEF_STATEINITIALIZATION
#include "cudaq/Optimizer/Transforms/Passes.h.inc"
} // namespace cudaq::opt

#define DEBUG_TYPE "state-initialization"

using namespace mlir;

namespace {

static bool isCall(Operation *op, std::vector<const char *> &&names) {
if (op) {
if (auto callOp = dyn_cast<func::CallOp>(op)) {
if (auto calleeAttr = callOp.getCalleeAttr()) {
auto funcName = calleeAttr.getValue().str();
if (std::find(names.begin(), names.end(), funcName) != names.end())
return true;
annagrin marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
return false;
}

static bool isGetStateCall(Operation *op) {
return isCall(op, {cudaq::getCudaqState});
}

static bool isNumberOfQubitsCall(Operation *op) {
return isCall(op, {cudaq::getNumQubitsFromCudaqState});
}

// clang-format off
/// Replace `quake.init_state` by a call to a (modified) kernel that produced the state.
annagrin marked this conversation as resolved.
Show resolved Hide resolved
/// ```
/// %0 = cc.string_literal "callee.modified_0" : !cc.ptr<!cc.array<i8 x 27>>
/// %1 = cc.cast %0 : (!cc.ptr<!cc.array<i8 x 27>>) -> !cc.ptr<i8>
/// %2 = call @__nvqpp_cudaq_state_get(%1) : (!cc.ptr<i8>) -> !cc.ptr<!cc.state>
/// %3 = call @__nvqpp_cudaq_state_numberOfQubits(%2) : (!cc.ptr<!cc.state>) -> i64
/// %4 = quake.alloca !quake.veq<?>[%3 : i64]
/// %5 = quake.init_state %4, %2 : (!quake.veq<?>, !cc.ptr<!cc.state>) -> !quake.veq<?>
/// ───────────────────────────────────────────
/// ...
/// %5 = call @callee.modified_0() : () -> !quake.veq<?>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This violates the semantics of Quake: all quantum memory allocations happen at the top-level and cannot be returned from other kernels.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is only temporary and does not seem to break anything until the Inlining fixes the semantics... But returning allocations (instead of passing them as a parameter) is much easier to implement. Let me know if it is a showstopper, I can revisit passing allocations as a parameter.

Copy link
Collaborator

@schweitzpgi schweitzpgi Oct 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, it's a bug even if it "works" for now, so I think we ought to deal with it now.

I don't see why returning is any easier. Updating the "init kernel" from something like

func.func @init_kernel() {
   %0 = quake.alloca !quake.veq<4>
   ...
}

to a "modified init kernel" of

func.func @init_kernel.modified(%0 : !quake.veq<4>) {
   ...
}

is easy enough.

Also, if we pass in the allocated qubits, we don't have to erase them at the top-level or rely on inlining to get things normalized, etc.

Copy link
Collaborator Author

@annagrin annagrin Oct 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The callee can have multiple allocations, so I would need to know how to split the qvector to pass to those allocations. Not sure how to do that yet (process the function to get allocation sizes first? The size can be a result of a _getNumberOfQubits call though...)

I need time to try changing to passing allocations as parameters, probably won't fit into this release, unless I only support a trivial case of one allocation in a callee kernel. Do you think this part can be done in a later PR?

/// ```
// clang-format on
class StateInitPattern : public OpRewritePattern<quake::InitializeStateOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(quake::InitializeStateOp initState,
PatternRewriter &rewriter) const override {
auto loc = initState.getLoc();
auto allocaOp = initState.getOperand(0).getDefiningOp();
annagrin marked this conversation as resolved.
Show resolved Hide resolved
auto stateOp = initState.getOperand(1);

if (isa<cudaq::cc::StateType>(stateOp.getType())) {
auto getStateOp = stateOp.getDefiningOp();
annagrin marked this conversation as resolved.
Show resolved Hide resolved
auto numOfQubits = allocaOp->getOperand(0).getDefiningOp();
annagrin marked this conversation as resolved.
Show resolved Hide resolved

if (isGetStateCall(getStateOp)) {
auto calleeNameOp = getStateOp->getOperand(0);
if (auto cast =
dyn_cast<cudaq::cc::CastOp>(calleeNameOp.getDefiningOp())) {
annagrin marked this conversation as resolved.
Show resolved Hide resolved
calleeNameOp = cast.getOperand();

if (auto literal = dyn_cast<cudaq::cc::CreateStringLiteralOp>(
calleeNameOp.getDefiningOp())) {
annagrin marked this conversation as resolved.
Show resolved Hide resolved
auto calleeName = literal.getStringLiteral();

Value result =
rewriter
.create<func::CallOp>(loc, initState.getType(), calleeName,
mlir::ValueRange{})
.getResult(0);
rewriter.replaceAllUsesWith(initState, result);
annagrin marked this conversation as resolved.
Show resolved Hide resolved
initState.erase();
allocaOp->dropAllUses();
rewriter.eraseOp(allocaOp);
annagrin marked this conversation as resolved.
Show resolved Hide resolved
if (isNumberOfQubitsCall(numOfQubits)) {
numOfQubits->dropAllUses();
rewriter.eraseOp(numOfQubits);
}
getStateOp->dropAllUses();
rewriter.eraseOp(getStateOp);
cast->dropAllUses();
rewriter.eraseOp(cast);
literal->dropAllUses();
rewriter.eraseOp(literal);
return success();
}
}
}
}
return failure();
annagrin marked this conversation as resolved.
Show resolved Hide resolved
}
};

class StateInitializationPass
: public cudaq::opt::impl::StateInitializationBase<
StateInitializationPass> {
public:
using StateInitializationBase::StateInitializationBase;

void runOnOperation() override {
auto *ctx = &getContext();
auto func = getOperation();
RewritePatternSet patterns(ctx);
patterns.insert<StateInitPattern>(ctx);

LLVM_DEBUG(llvm::dbgs() << "Before state initialization: " << func << '\n');

if (failed(applyPatternsAndFoldGreedily(func.getOperation(),
std::move(patterns))))
signalPassFailure();

LLVM_DEBUG(llvm::dbgs() << "After state initialization: " << func << '\n');
}
};
} // namespace
127 changes: 127 additions & 0 deletions lib/Optimizer/Transforms/StateValidation.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*******************************************************************************
* Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. *
* All rights reserved. *
* *
* This source code and the accompanying materials are made available under *
* the terms of the Apache License 2.0 which accompanies this distribution. *
******************************************************************************/

#include "PassDetails.h"
#include "cudaq/Optimizer/Builder/Intrinsics.h"
#include "cudaq/Optimizer/Dialect/CC/CCOps.h"
#include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h"
#include "cudaq/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

namespace cudaq::opt {
#define GEN_PASS_DEF_STATEVALIDATION
#include "cudaq/Optimizer/Transforms/Passes.h.inc"
} // namespace cudaq::opt

#define DEBUG_TYPE "state-validation"

using namespace mlir;

/// Validate that quantum code does not contain runtime calls and remove runtime
/// function definitions.
namespace {

static bool isRuntimeStateCallName(llvm::StringRef funcName) {
static std::vector<const char *> names = {
cudaq::getCudaqState, cudaq::createCudaqStateFromDataFP32,
cudaq::createCudaqStateFromDataFP64, cudaq::deleteCudaqState,
cudaq::getNumQubitsFromCudaqState};
if (std::find(names.begin(), names.end(), funcName) != names.end())
return true;
return false;
}

static bool isRuntimeStateCall(Operation *callOp) {
if (callOp) {
if (auto call = dyn_cast<func::CallOp>(callOp)) {
if (auto calleeAttr = call.getCalleeAttr()) {
auto funcName = calleeAttr.getValue().str();
if (isRuntimeStateCallName(funcName))
return true;
}
}
}
return false;
}

class ValidateStateCallPattern : public OpRewritePattern<func::CallOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(func::CallOp callOp,
PatternRewriter &rewriter) const override {
if (isRuntimeStateCall(callOp)) {
auto name = callOp.getCalleeAttr().getValue();
callOp.emitError(
"Synthesis did not remove func call for quantum platform: " + name);
}
return failure();
}
};

class ValidateStateInitPattern
: public OpRewritePattern<quake::InitializeStateOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(quake::InitializeStateOp initState,
PatternRewriter &rewriter) const override {
auto stateOp = initState.getOperand(1);
if (isa<cudaq::cc::StateType>(stateOp.getType()))
initState.emitError("Synthesis did not remove `quake.init_state <veq> "
"<state>` instruction");

return failure();
}
};

class StateValidationPass
annagrin marked this conversation as resolved.
Show resolved Hide resolved
: public cudaq::opt::impl::StateValidationBase<StateValidationPass> {
protected:
public:
using StateValidationBase::StateValidationBase;

mlir::ModuleOp getModule() { return getOperation(); }

void runOnOperation() override final {
auto *ctx = &getContext();
auto module = getModule();
SmallVector<Operation *> toErase;

for (Operation &op : *module.getBody()) {
auto func = dyn_cast<func::FuncOp>(op);
if (!func)
continue;

RewritePatternSet patterns(ctx);
patterns.insert<ValidateStateCallPattern, ValidateStateInitPattern>(ctx);

LLVM_DEBUG(llvm::dbgs() << "Before state validation: " << func << '\n');

if (failed(applyPatternsAndFoldGreedily(func.getOperation(),
std::move(patterns))))
signalPassFailure();

// Delete runtime function definitions.
if (func.getBody().empty() && isRuntimeStateCallName(func.getName()))
toErase.push_back(func);

LLVM_DEBUG(llvm::dbgs() << "After state validation: " << func << '\n');
}

for (auto *op : toErase)
op->erase();
}
};

} // namespace
5 changes: 3 additions & 2 deletions python/runtime/cudaq/algorithms/py_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,9 @@ class PyRemoteSimulationState : public RemoteSimulationState {
}
}

std::pair<std::string, std::vector<void *>> getKernelInfo() const override {
return {kernelName, argsData->getArgs()};
std::optional<std::pair<std::string, std::vector<void *>>>
getKernelInfo() const override {
return std::make_pair(kernelName, argsData->getArgs());
}

std::complex<double> overlap(const cudaq::SimulationState &other) override {
Expand Down
2 changes: 1 addition & 1 deletion python/runtime/cudaq/platform/py_alt_launch_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ MlirModule synthesizeKernel(const std::string &name, MlirModule module,
auto isLocalSimulator = platform.is_simulator() && !platform.is_emulated();
auto isSimulator = isLocalSimulator || isRemoteSimulator;

cudaq::opt::ArgumentConverter argCon(name, unwrap(module), isSimulator);
cudaq::opt::ArgumentConverter argCon(name, unwrap(module));
argCon.gen(runtimeArgs.getArgs());
std::string kernName = cudaq::runtime::cudaqGenPrefixName + name;
SmallVector<StringRef> kernels = {kernName};
Expand Down
Loading
Loading