Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Added loop boundary optimization pass #1476

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/Quantum/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,6 @@ std::unique_ptr<mlir::Pass> createDisentangleCNOTPass();
std::unique_ptr<mlir::Pass> createDisentangleSWAPPass();
std::unique_ptr<mlir::Pass> createIonsDecompositionPass();
std::unique_ptr<mlir::Pass> createStaticCustomLoweringPass();
std::unique_ptr<mlir::Pass> createLoopBoundaryOptimizationPass();

} // namespace catalyst
5 changes: 5 additions & 0 deletions mlir/include/Quantum/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,11 @@ def DisentangleSWAPPass : Pass<"disentangle-SWAP"> {
];
}

def LoopBoundaryOptimizationPass : Pass<"loop-boundary"> {
let summary = "Perform loop boundary optimization to eliminate the redundancy of operations on loop boundary.";

let constructor = "catalyst::createLoopBoundaryOptimizationPass()";
}
// ----- Quantum circuit transformation passes end ----- //

#endif // QUANTUM_PASSES
1 change: 1 addition & 0 deletions mlir/include/Quantum/Transforms/Patterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ void populateSelfInversePatterns(mlir::RewritePatternSet &);
void populateMergeRotationsPatterns(mlir::RewritePatternSet &);
void populateIonsDecompositionPatterns(mlir::RewritePatternSet &);
void populateStaticCustomPatterns(mlir::RewritePatternSet &);
void populateLoopBoundaryPatterns(mlir::RewritePatternSet &);

} // namespace quantum
} // namespace catalyst
1 change: 1 addition & 0 deletions mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,5 @@ void catalyst::registerAllCatalystPasses()
mlir::registerPass(catalyst::createTestPass);
mlir::registerPass(catalyst::createIonsDecompositionPass);
mlir::registerPass(catalyst::createQuantumToIonPass);
mlir::registerPass(catalyst::createLoopBoundaryOptimizationPass);
}
2 changes: 2 additions & 0 deletions mlir/lib/Quantum/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ file(GLOB SRC
IonsDecompositionPatterns.cpp
static_custom_lowering.cpp
StaticCustomPatterns.cpp
loop_boundary_optimization.cpp
LoopBoundaryOptimizationPatterns.cpp
)

get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
Expand Down
130 changes: 130 additions & 0 deletions mlir/lib/Quantum/Transforms/LoopBoundaryOptimizationPatterns.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
// Copyright 2024 Xanadu Quantum Technologies Inc.
sengthai marked this conversation as resolved.
Show resolved Hide resolved

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#define DEBUG_TYPE "loop-boundary"

#include "Quantum/IR/QuantumOps.h"
#include "Quantum/Transforms/Patterns.h"
#include "VerifyParentGateAnalysis.hpp"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Errc.h"

#include "mlir/IR/Operation.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"

using llvm::dbgs;
using namespace mlir;
using namespace catalyst::quantum;

static const mlir::StringSet<> rotationsSet = {"RX", "RY", "RZ", "PhaseShift",
"CRX", "CRY", "CRZ", "ControlledPhaseShift"};

static const mlir::StringSet<> hamiltonianSet = {"H", "X", "Y", "Z"};

namespace {

// TODO: Reduce the complexity of the function
// TODO: Support multi-qubit gates
template <typename OpType>
std::map<OpType, std::vector<mlir::Value>> traceOperationQubit(mlir::Block *block)
{
std::map<OpType, std::vector<mlir::Value>> opMap;
block->walk([&](OpType op) {
mlir::Value operand = op.getInQubits()[0];

while (auto definingOp = dyn_cast_or_null<CustomOp>(operand.getDefiningOp())) {
operand = definingOp.getInQubits()[0];
}

opMap[op].push_back(operand);
});
return opMap;
}

struct LoopBoundaryForLoopRewritePattern : public mlir::OpRewritePattern<scf::ForOp> {
using mlir::OpRewritePattern<scf::ForOp>::OpRewritePattern;

mlir::LogicalResult matchAndRewrite(scf::ForOp forOp,
mlir::PatternRewriter &rewriter) const override
{
LLVM_DEBUG(dbgs() << "Simplifying the following operation:\n" << forOp << "\n");

// if number of operations in the loop is less than 2, return
if (forOp.getBody()->getOperations().size() < 2) {
return mlir::failure();
}

auto opMap = traceOperationQubit<quantum::CustomOp>(forOp.getBody());

quantum::CustomOp firstGateOp = opMap.begin()->first;
quantum::CustomOp secondGateOp = opMap.rbegin()->first;

if (opMap.begin()->first == opMap.rbegin()->first) {
return mlir::failure();
}

if (opMap[firstGateOp][0] == opMap[secondGateOp][0] &&
firstGateOp.getGateName() == secondGateOp.getGateName()) {

Check notice on line 82 in mlir/lib/Quantum/Transforms/LoopBoundaryOptimizationPatterns.cpp

View check run for this annotation

codefactor.io / CodeFactor

mlir/lib/Quantum/Transforms/LoopBoundaryOptimizationPatterns.cpp#L82

Redundant blank line at the start of a code block should be deleted. (whitespace/blank_line)
// create new top-edge gate
auto firstOp = firstGateOp.clone();
firstOp->setOperands(forOp.getInitArgs());
rewriter.setInsertionPoint(forOp);
rewriter.insert(firstOp);

// config the operand of for-loop to be the result of the first gate
forOp.setOperand(3, firstOp.getResult(0));

// config the successor of the first gate to be the for-loop
firstGateOp.getOutQubits().replaceAllUsesWith(firstGateOp.getInQubits());

// erase the first gate
rewriter.eraseOp(firstGateOp);

// create new bottom-edge gate
auto secondOp = secondGateOp.clone();

// replace successor of for-loop with the second gate
forOp.getResults().replaceAllUsesWith(secondOp);

// set the operands of the
secondOp->setOperands(forOp.getResults());
rewriter.setInsertionPointAfter(forOp);
rewriter.insert(secondOp);

// config the successor of the second gate to be the successor of the for-loop
secondGateOp.getOutQubits().replaceAllUsesWith(secondGateOp.getInQubits());
rewriter.eraseOp(secondGateOp);

return mlir::success();
}

Check notice on line 115 in mlir/lib/Quantum/Transforms/LoopBoundaryOptimizationPatterns.cpp

View check run for this annotation

codefactor.io / CodeFactor

mlir/lib/Quantum/Transforms/LoopBoundaryOptimizationPatterns.cpp#L115

Redundant blank line at the end of a code block should be deleted. (whitespace/blank_line)
}
};

} // namespace

namespace catalyst {
namespace quantum {

void populateLoopBoundaryPatterns(mlir::RewritePatternSet &patterns)
{
patterns.add<LoopBoundaryForLoopRewritePattern>(patterns.getContext(), 1);
}

} // namespace quantum
} // namespace catalyst
76 changes: 76 additions & 0 deletions mlir/lib/Quantum/Transforms/loop_boundary_optimization.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Copyright 2024 Xanadu Quantum Technologies Inc.
sengthai marked this conversation as resolved.
Show resolved Hide resolved

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#define DEBUG_TYPE "loop-boundary"

#include "Catalyst/IR/CatalystDialect.h"
#include "Quantum/IR/QuantumOps.h"
#include "Quantum/Transforms/Patterns.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/RegionKindInterface.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/LoopLikeInterface.h"

using namespace llvm;
using namespace mlir;
using namespace catalyst::quantum;

namespace catalyst {
namespace quantum {

#define GEN_PASS_DEF_LOOPBOUNDARYOPTIMIZATIONPASS
#define GEN_PASS_DECL_LOOPBOUNDARYOPTIMIZATIONPASS
#include "Quantum/Transforms/Passes.h.inc"

struct LoopBoundaryOptimizationPass
: impl::LoopBoundaryOptimizationPassBase<LoopBoundaryOptimizationPass> {
using LoopBoundaryOptimizationPassBase::LoopBoundaryOptimizationPassBase;

void runOnOperation() final
{
LLVM_DEBUG(dbgs() << "loop boundary optimization pass"
<< "\n");

Operation *module = getOperation();

RewritePatternSet patternsCanonicalization(&getContext());
scf::ForOp::getCanonicalizationPatterns(patternsCanonicalization, &getContext());

if (failed(applyPatternsAndFoldGreedily(module, std::move(patternsCanonicalization)))) {
return signalPassFailure();
}

RewritePatternSet patterns(&getContext());
populateLoopBoundaryPatterns(patterns);

if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) {
return signalPassFailure();
}
}
};

} // namespace quantum

std::unique_ptr<Pass> createLoopBoundaryOptimizationPass()
{
return std::make_unique<quantum::LoopBoundaryOptimizationPass>();
}

} // namespace catalyst
39 changes: 39 additions & 0 deletions mlir/test/Quantum/LoopBoundaryOptimization.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright 2024 Xanadu Quantum Technologies Inc.
sengthai marked this conversation as resolved.
Show resolved Hide resolved

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// RUN: quantum-opt --loop-boundary --split-input-file -verify-diagnostics %s | FileCheck %s

func.func @test(%q: !quantum.bit) -> !quantum.bit {
%start = arith.constant 0 : index
%stop = arith.constant 10 : index
%step = arith.constant 1 : index
%phi = arith.constant 0.1 : f64

// CHECK: func @test([[arg:%.+]]: !quantum.bit) -> !quantum.bit {
// CHECK: [[qubit_0:%.+]] = quantum.custom "H"() [[arg]] : !quantum.bit

// CHECK: [[qubit_1:%.+]] = scf.for {{.*}} iter_args([[qubit_2:%.+]] = [[qubit_0]]) -> (!quantum.bit) {
%qq = scf.for %i = %start to %stop step %step iter_args(%q_0 = %q) -> (!quantum.bit) {
%q_1 = quantum.custom "H"() %q_0 : !quantum.bit
// CHECK: [[qubit_3:%.+]] = quantum.custom "RY"{{.*}} [[qubit_2]] : !quantum.bit
%q_2 = quantum.custom "RY"(%phi) %q_1 : !quantum.bit
%q_3 = quantum.custom "H"() %q_2 : !quantum.bit
// CHECK-NEXT: scf.yield [[qubit_3]] : !quantum.bit
scf.yield %q_3 : !quantum.bit
}

// CHECK: [[qubit_4:%.+]] = quantum.custom "H"() [[qubit_1]] : !quantum.bit
// CHECK: return [[qubit_4]]
func.return %qq : !quantum.bit
}
Loading