From 762b9a749f40c4b77659cb5213297ecb92d0f120 Mon Sep 17 00:00:00 2001 From: Yevhenii Havrylko Date: Thu, 4 Jan 2024 15:21:53 -0500 Subject: [PATCH] Use dpjit target context --- numba_dpex/core/dpjit_dispatcher.py | 63 ++++++++++++++++--- .../core/passes/parfor_legalize_cfd_pass.py | 16 ++++- numba_dpex/core/targets/dpjit_target.py | 27 ++++---- numba_dpex/decorators.py | 10 ++- 4 files changed, 83 insertions(+), 33 deletions(-) diff --git a/numba_dpex/core/dpjit_dispatcher.py b/numba_dpex/core/dpjit_dispatcher.py index 74beb37082..c6bdb6b942 100644 --- a/numba_dpex/core/dpjit_dispatcher.py +++ b/numba_dpex/core/dpjit_dispatcher.py @@ -2,14 +2,44 @@ # # SPDX-License-Identifier: Apache-2.0 -from numba.core import compiler, dispatcher -from numba.core.target_extension import dispatcher_registry, target_registry +from numba.core import dispatcher, errors +from numba.core.target_extension import ( + dispatcher_registry, + target_override, + target_registry, +) +from numba_dpex import numba_sem_version +from numba_dpex.core.pipelines import dpjit_compiler from numba_dpex.core.targets.dpjit_target import DPEX_TARGET_NAME from .descriptor import dpex_target +class _DpjitCompiler(dispatcher._FunctionCompiler): + """A special compiler class used to compile numba_dpex.dpjit decorated + functions. + """ + + def _compile_cached(self, args, return_type): + # follows the same logic as original one, but triggers _compile_core + # with dpex target overload. + key = tuple(args), return_type + try: + return False, self._failed_cache[key] + except KeyError: + pass + + try: + with target_override(DPEX_TARGET_NAME): + retval = self._compile_core(args, return_type) + except errors.TypingError as e: + self._failed_cache[key] = e + return False, e + else: + return True, retval + + class DpjitDispatcher(dispatcher.Dispatcher): """A dpex.djit-specific dispatcher. @@ -26,16 +56,29 @@ def __init__( py_func, locals={}, targetoptions={}, - impl_kind="direct", - pipeline_class=compiler.Compiler, + pipeline_class=dpjit_compiler.DpjitCompiler, ): - dispatcher.Dispatcher.__init__( - self, + if numba_sem_version < (0, 59, 0): + super().__init__( + py_func=py_func, + locals=locals, + impl_kind="direct", + targetoptions=targetoptions, + pipeline_class=pipeline_class, + ) + else: + super().__init__( + py_func=py_func, + locals=locals, + targetoptions=targetoptions, + pipeline_class=pipeline_class, + ) + self._compiler = _DpjitCompiler( py_func, - locals=locals, - targetoptions=targetoptions, - impl_kind=impl_kind, - pipeline_class=pipeline_class, + self.targetdescr, + targetoptions, + locals, + pipeline_class, ) diff --git a/numba_dpex/core/passes/parfor_legalize_cfd_pass.py b/numba_dpex/core/passes/parfor_legalize_cfd_pass.py index 6c5a10760b..81cabf9c82 100644 --- a/numba_dpex/core/passes/parfor_legalize_cfd_pass.py +++ b/numba_dpex/core/passes/parfor_legalize_cfd_pass.py @@ -42,6 +42,8 @@ class ParforLegalizeCFDPassImpl: """ + # TODO: fix point algorithm implementation + inputUsmTypeStrToInt = {"device": 3, "shared": 2, "host": 1} inputUsmTypeIntToStr = {3: "device", 2: "shared", 1: "host"} @@ -59,7 +61,14 @@ def _check_if_dpnp_empty_call(self, call_stmt, block): and isinstance(func_def.value, ir.Expr) and func_def.value.op == "getattr" ): - raise AssertionError + # TODO: write unit test that this check passes for the instruction + # generated by dpnp.empty() after type inferring pass + # Possible implementation: + # 1. dpjit with dpnp.empty() + # 2. generate Numba IR after type inference for it by using dummy + # pipeline + # 3. check that we have assertion making place + return False module_name = block.find_variable_assignment( func_def.value.list_vars()[0].name @@ -162,6 +171,7 @@ def _legalize_parfor_params(self, parfor): str: The device filter string for the parfor if the parfor is compute follows data conforming. """ + # TODO: check if attribute is dpnp specific if parfor.params is None: return @@ -205,9 +215,11 @@ def _legalize_cfd_parfor_blocks(self, parfor): """Legalize the parfor params based on the compute follows data programming model and usm allocator precedence rule. """ + # this function just sets queue for dpnp and empty attribute for it conforming_device_ty = self._legalize_parfor_params(parfor) # Update the parfor's lowerer attribute + # this sets if we are in dpnp or numpy parfor parfor.lowerer = ParforLowerFactory.get_lowerer(conforming_device_ty) init_block = parfor.init_block @@ -221,6 +233,7 @@ def _legalize_cfd_parfor_blocks(self, parfor): self._legalize_stmt(stmt, block, inparfor=True) def _legalize_expr(self, stmt, lhs, lhsty, parent_block, inparfor=False): + # TODO: rename to infer queue rhs = stmt.value if rhs.op == "call": if self._check_if_dpnp_empty_call(rhs.func, parent_block): @@ -305,6 +318,7 @@ def run(self): @register_pass(mutates_CFG=True, analysis_only=False) class ParforLegalizeCFDPass(FunctionPass): + # TODO: rename to execution queue inferring Pass _name = "parfor_Legalize_CFD_pass" def __init__(self): diff --git a/numba_dpex/core/targets/dpjit_target.py b/numba_dpex/core/targets/dpjit_target.py index ae78cfc190..883a00bc5e 100644 --- a/numba_dpex/core/targets/dpjit_target.py +++ b/numba_dpex/core/targets/dpjit_target.py @@ -11,7 +11,7 @@ from numba.core.codegen import JITCPUCodegen from numba.core.compiler_lock import global_compiler_lock from numba.core.cpu import CPUContext -from numba.core.imputils import Registry, RegistryLoader +from numba.core.imputils import Registry from numba.core.target_extension import CPU, target_registry @@ -34,15 +34,14 @@ def __init__(self, typingctx, target=DPEX_TARGET_NAME): @global_compiler_lock def init(self): - self.is32bit = utils.MACHINE_BITS == 32 - self._internal_codegen = JITCPUCodegen("numba.exec") self.lower_extensions = {} + super().init() + # TODO: initialize nrt once switched to nrt from drt. Most likely we # call it somewhere. Double check. # https://github.com/IntelPython/numba-dpex/issues/1175 # Initialize NRT runtime # rtsys.initialize(self) # noqa: E800 - self.refresh() @cached_property def dpexrt(self): @@ -50,15 +49,11 @@ def dpexrt(self): return DpexRTContext(self) - def refresh(self): - registry = dpex_function_registry - try: - loader = self._registries[registry] - except KeyError: - loader = RegistryLoader(registry) - self._registries[registry] = loader - self.install_registry(registry) - # Also refresh typing context, since @overload declarations can - # affect it. - self.typing_context.refresh() - super().refresh() + def load_additional_registries(self): + """ + Load dpjit-specific registries. + """ + self.install_registry(dpex_function_registry) + + # loading CPU specific registries + super().load_additional_registries() diff --git a/numba_dpex/decorators.py b/numba_dpex/decorators.py index cf9dbe933c..3022c93861 100644 --- a/numba_dpex/decorators.py +++ b/numba_dpex/decorators.py @@ -14,6 +14,7 @@ compile_func_template, ) from numba_dpex.core.pipelines.dpjit_compiler import get_compiler +from numba_dpex.core.targets.dpjit_target import DPEX_TARGET_NAME from .config import USE_MLIR @@ -140,7 +141,7 @@ def _wrapped(pyfunc): def dpjit(*args, **kws): - if "nopython" in kws: + if "nopython" in kws and kws["nopython"] is not True: warnings.warn( "nopython is set for dpjit and is ignored", RuntimeWarning ) @@ -161,14 +162,11 @@ def dpjit(*args, **kws): kws.update({"parallel": True}) kws.update({"pipeline_class": get_compiler(use_mlir)}) - # FIXME: When trying to use dpex's target context, overloads do not work - # properly. We will turn on dpex target once the issue is fixed. - - # kws.update({"_target": "dpex"}) # noqa: E800 + kws.update({"_target": DPEX_TARGET_NAME}) return decorators.jit(*args, **kws) # add it to the decorator registry, this is so e.g. @overload can look up a # JIT function to do the compilation work. -jit_registry[target_registry["dpex"]] = dpjit +jit_registry[target_registry[DPEX_TARGET_NAME]] = dpjit