From 5b4059b50091b3d48d7160cb140478955cbdc752 Mon Sep 17 00:00:00 2001 From: boschmitt <7152025+boschmitt@users.noreply.github.com> Date: Fri, 29 Nov 2024 16:45:51 +0100 Subject: [PATCH] [python] Remove unused MLIR components We don't need to take everything from MLIR for our python bindings. This change cherry picks the upstream components our compiler depends on. The commit also cleans up some unnecessary code that ends up registering dialects more than once, and surfaces the `register_all_dialects` function to a less obscure location. Signed-off-by: boschmitt <7152025+boschmitt@users.noreply.github.com> --- python/cudaq/kernel/ast_bridge.py | 2 - python/cudaq/kernel/kernel_builder.py | 5 +- python/cudaq/mlir/__init__.py | 9 +++ python/extension/CMakeLists.txt | 22 ++++--- python/runtime/mlir/py_register_dialects.cpp | 65 ++++---------------- python/tests/mlir/bare.py | 3 +- 6 files changed, 41 insertions(+), 65 deletions(-) create mode 100644 python/cudaq/mlir/__init__.py diff --git a/python/cudaq/kernel/ast_bridge.py b/python/cudaq/kernel/ast_bridge.py index 5f566768c2..2726b0da92 100644 --- a/python/cudaq/kernel/ast_bridge.py +++ b/python/cudaq/kernel/ast_bridge.py @@ -123,8 +123,6 @@ def __init__(self, capturedDataStorage: CapturedDataStorage, **kwargs): else: self.ctx = Context() register_all_dialects(self.ctx) - quake.register_dialect(self.ctx) - cc.register_dialect(self.ctx) cudaq_runtime.registerLLVMDialectTranslation(self.ctx) self.loc = Location.unknown(context=self.ctx) self.module = Module.create(loc=self.loc) diff --git a/python/cudaq/kernel/kernel_builder.py b/python/cudaq/kernel/kernel_builder.py index 192aad8929..f587a01eb2 100644 --- a/python/cudaq/kernel/kernel_builder.py +++ b/python/cudaq/kernel/kernel_builder.py @@ -37,7 +37,8 @@ # We need static initializers to run in the CAPI `ExecutionEngine`, # so here we run a simple JIT compile at global scope -with Context(): +with Context() as ctx: + register_all_dialects(ctx) module = Module.parse(r""" llvm.func @none() { llvm.return @@ -246,8 +247,6 @@ class PyKernel(object): def __init__(self, argTypeList): self.ctx = Context() register_all_dialects(self.ctx) - quake.register_dialect(self.ctx) - cc.register_dialect(self.ctx) cudaq_runtime.registerLLVMDialectTranslation(self.ctx) self.metadata = {'conditionalOnMeasure': False} diff --git a/python/cudaq/mlir/__init__.py b/python/cudaq/mlir/__init__.py new file mode 100644 index 0000000000..eda2e6614f --- /dev/null +++ b/python/cudaq/mlir/__init__.py @@ -0,0 +1,9 @@ +# ============================================================================ # +# Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. # +# All rights reserved. # +# # +# This source code and the accompanying materials are made available under # +# the terms of the Apache License 2.0 which accompanies this distribution. # +# ============================================================================ # + +from ._mlir_libs._quakeDialects import register_all_dialects diff --git a/python/extension/CMakeLists.txt b/python/extension/CMakeLists.txt index 021cb6da77..d2886dec54 100644 --- a/python/extension/CMakeLists.txt +++ b/python/extension/CMakeLists.txt @@ -119,10 +119,14 @@ add_mlir_python_common_capi_library(CUDAQuantumPythonCAPI RELATIVE_INSTALL_ROOT "../.." DECLARED_SOURCES CUDAQuantumPythonSources - # TODO: Remove this in favor of showing fine grained registration once - # available. - MLIRPythonExtension.RegisterEverything MLIRPythonSources.Core + MLIRPythonSources.Dialects.arith + MLIRPythonSources.Dialects.builtin + MLIRPythonSources.Dialects.cf + MLIRPythonSources.Dialects.complex + MLIRPythonSources.Dialects.func + MLIRPythonSources.Dialects.math + MLIRPythonSources.ExecutionEngine ) ################################################################################ @@ -134,10 +138,14 @@ add_mlir_python_modules(CUDAQuantumPythonModules INSTALL_PREFIX "cudaq/mlir" DECLARED_SOURCES CUDAQuantumPythonSources - # TODO: Remove this in favor of showing fine grained registration once - # available. - MLIRPythonExtension.RegisterEverything - MLIRPythonSources + MLIRPythonSources.Core + MLIRPythonSources.Dialects.arith + MLIRPythonSources.Dialects.builtin + MLIRPythonSources.Dialects.cf + MLIRPythonSources.Dialects.complex + MLIRPythonSources.Dialects.func + MLIRPythonSources.Dialects.math + MLIRPythonSources.ExecutionEngine COMMON_CAPI_LINK_LIBS CUDAQuantumPythonCAPI ) diff --git a/python/runtime/mlir/py_register_dialects.cpp b/python/runtime/mlir/py_register_dialects.cpp index 9c0c4f2985..c579298f57 100644 --- a/python/runtime/mlir/py_register_dialects.cpp +++ b/python/runtime/mlir/py_register_dialects.cpp @@ -6,19 +6,13 @@ * the terms of the Apache License 2.0 which accompanies this distribution. * ******************************************************************************/ -#include "mlir/Bindings/Python/PybindAdaptors.h" - #include "cudaq/Optimizer/Builder/Intrinsics.h" -#include "cudaq/Optimizer/CAPI/Dialects.h" -#include "cudaq/Optimizer/CodeGen/Passes.h" -#include "cudaq/Optimizer/CodeGen/Pipelines.h" -#include "cudaq/Optimizer/Dialect/CC/CCDialect.h" -#include "cudaq/Optimizer/Dialect/CC/CCOps.h" #include "cudaq/Optimizer/Dialect/CC/CCTypes.h" -#include "cudaq/Optimizer/Dialect/Quake/QuakeDialect.h" #include "cudaq/Optimizer/Dialect/Quake/QuakeTypes.h" -#include "cudaq/Optimizer/Transforms/Passes.h" -#include "mlir/InitAllDialects.h" +#include "cudaq/Optimizer/InitAllDialects.h" +#include "cudaq/Optimizer/InitAllPasses.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/CAPI/IR.h" #include #include #include @@ -28,32 +22,10 @@ using namespace mlir::python::adaptors; using namespace mlir; namespace cudaq { -static bool registered = false; -void registerQuakeDialectAndTypes(py::module &m) { +void registerQuakeTypes(py::module &m) { auto quakeMod = m.def_submodule("quake"); - quakeMod.def( - "register_dialect", - [](MlirContext context, bool load) { - MlirDialectHandle handle = mlirGetDialectHandle__quake__(); - mlirDialectHandleRegisterDialect(handle, context); - if (load) { - mlirDialectHandleLoadDialect(handle, context); - } - - if (!registered) { - cudaq::opt::registerOptCodeGenPasses(); - cudaq::opt::registerOptTransformsPasses(); - cudaq::opt::registerAggressiveEarlyInlining(); - cudaq::opt::registerUnrollingPipeline(); - cudaq::opt::registerTargetPipelines(); - cudaq::opt::registerMappingPipeline(); - registered = true; - } - }, - py::arg("context") = py::none(), py::arg("load") = true); - mlir_type_subclass(quakeMod, "RefType", [](MlirType type) { return unwrap(type).isa(); }).def_classmethod("get", [](py::object cls, MlirContext ctx) { @@ -143,21 +115,10 @@ void registerQuakeDialectAndTypes(py::module &m) { }); } -void registerCCDialectAndTypes(py::module &m) { +void registerCCTypes(py::module &m) { auto ccMod = m.def_submodule("cc"); - ccMod.def( - "register_dialect", - [](MlirContext context, bool load) { - MlirDialectHandle ccHandle = mlirGetDialectHandle__cc__(); - mlirDialectHandleRegisterDialect(ccHandle, context); - if (load) { - mlirDialectHandleLoadDialect(ccHandle, context); - } - }, - py::arg("context") = py::none(), py::arg("load") = true); - mlir_type_subclass(ccMod, "CharspanType", [](MlirType type) { return unwrap(type).isa(); }).def_classmethod("get", [](py::object cls, MlirContext ctx) { @@ -298,9 +259,6 @@ void registerCCDialectAndTypes(py::module &m) { } void bindRegisterDialects(py::module &mod) { - registerQuakeDialectAndTypes(mod); - registerCCDialectAndTypes(mod); - mod.def("load_intrinsic", [](MlirModule module, std::string name) { auto unwrapped = unwrap(module); cudaq::IRBuilder builder = IRBuilder::atBlockEnd(unwrapped.getBody()); @@ -310,14 +268,17 @@ void bindRegisterDialects(py::module &mod) { mod.def("register_all_dialects", [](MlirContext context) { DialectRegistry registry; - registry.insert(); - cudaq::opt::registerCodeGenDialect(registry); - registerAllDialects(registry); - auto *mlirContext = unwrap(context); + cudaq::registerAllDialects(registry); + MLIRContext *mlirContext = unwrap(context); mlirContext->appendDialectRegistry(registry); mlirContext->loadAllAvailableDialects(); }); + // Register type as passes once, when the module is loaded. + registerQuakeTypes(mod); + registerCCTypes(mod); + cudaq::registerAllPasses(); + mod.def("gen_vector_of_complex_constant", [](MlirLocation loc, MlirModule module, std::string name, diff --git a/python/tests/mlir/bare.py b/python/tests/mlir/bare.py index 331bf40013..6320279763 100644 --- a/python/tests/mlir/bare.py +++ b/python/tests/mlir/bare.py @@ -8,12 +8,13 @@ # RUN: PYTHONPATH=../../ python3 %s | FileCheck %s +from cudaq.mlir import register_all_dialects from cudaq.mlir.ir import * from cudaq.mlir.dialects import quake from cudaq.mlir.dialects import builtin, func, arith with Context() as ctx: - quake.register_dialect() + register_all_dialects(ctx) m = Module.create(loc=Location.unknown()) with InsertionPoint(m.body), Location.unknown(): f = func.FuncOp('main', ([], []))