Skip to content

Commit

Permalink
Use dpjit target context
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Jan 4, 2024
1 parent 429d87c commit 762b9a7
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 33 deletions.
63 changes: 53 additions & 10 deletions numba_dpex/core/dpjit_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,44 @@
#
# SPDX-License-Identifier: Apache-2.0

from numba.core import compiler, dispatcher
from numba.core.target_extension import dispatcher_registry, target_registry
from numba.core import dispatcher, errors
from numba.core.target_extension import (
dispatcher_registry,
target_override,
target_registry,
)

from numba_dpex import numba_sem_version
from numba_dpex.core.pipelines import dpjit_compiler
from numba_dpex.core.targets.dpjit_target import DPEX_TARGET_NAME

from .descriptor import dpex_target


class _DpjitCompiler(dispatcher._FunctionCompiler):
"""A special compiler class used to compile numba_dpex.dpjit decorated
functions.
"""

def _compile_cached(self, args, return_type):
# follows the same logic as original one, but triggers _compile_core
# with dpex target overload.
key = tuple(args), return_type
try:
return False, self._failed_cache[key]
except KeyError:
pass

try:
with target_override(DPEX_TARGET_NAME):
retval = self._compile_core(args, return_type)
except errors.TypingError as e:
self._failed_cache[key] = e
return False, e
else:
return True, retval


class DpjitDispatcher(dispatcher.Dispatcher):
"""A dpex.djit-specific dispatcher.
Expand All @@ -26,16 +56,29 @@ def __init__(
py_func,
locals={},
targetoptions={},
impl_kind="direct",
pipeline_class=compiler.Compiler,
pipeline_class=dpjit_compiler.DpjitCompiler,
):
dispatcher.Dispatcher.__init__(
self,
if numba_sem_version < (0, 59, 0):
super().__init__(
py_func=py_func,
locals=locals,
impl_kind="direct",
targetoptions=targetoptions,
pipeline_class=pipeline_class,
)
else:
super().__init__(
py_func=py_func,
locals=locals,
targetoptions=targetoptions,
pipeline_class=pipeline_class,
)
self._compiler = _DpjitCompiler(
py_func,
locals=locals,
targetoptions=targetoptions,
impl_kind=impl_kind,
pipeline_class=pipeline_class,
self.targetdescr,
targetoptions,
locals,
pipeline_class,
)


Expand Down
16 changes: 15 additions & 1 deletion numba_dpex/core/passes/parfor_legalize_cfd_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class ParforLegalizeCFDPassImpl:
"""

# TODO: fix point algorithm implementation

inputUsmTypeStrToInt = {"device": 3, "shared": 2, "host": 1}
inputUsmTypeIntToStr = {3: "device", 2: "shared", 1: "host"}

Expand All @@ -59,7 +61,14 @@ def _check_if_dpnp_empty_call(self, call_stmt, block):
and isinstance(func_def.value, ir.Expr)
and func_def.value.op == "getattr"
):
raise AssertionError
# TODO: write unit test that this check passes for the instruction
# generated by dpnp.empty() after type inferring pass
# Possible implementation:
# 1. dpjit with dpnp.empty()
# 2. generate Numba IR after type inference for it by using dummy
# pipeline
# 3. check that we have assertion making place
return False

module_name = block.find_variable_assignment(
func_def.value.list_vars()[0].name
Expand Down Expand Up @@ -162,6 +171,7 @@ def _legalize_parfor_params(self, parfor):
str: The device filter string for the parfor if the parfor is
compute follows data conforming.
"""
# TODO: check if attribute is dpnp specific
if parfor.params is None:
return

Expand Down Expand Up @@ -205,9 +215,11 @@ def _legalize_cfd_parfor_blocks(self, parfor):
"""Legalize the parfor params based on the compute follows data
programming model and usm allocator precedence rule.
"""
# this function just sets queue for dpnp and empty attribute for it
conforming_device_ty = self._legalize_parfor_params(parfor)

# Update the parfor's lowerer attribute
# this sets if we are in dpnp or numpy parfor
parfor.lowerer = ParforLowerFactory.get_lowerer(conforming_device_ty)

init_block = parfor.init_block
Expand All @@ -221,6 +233,7 @@ def _legalize_cfd_parfor_blocks(self, parfor):
self._legalize_stmt(stmt, block, inparfor=True)

def _legalize_expr(self, stmt, lhs, lhsty, parent_block, inparfor=False):
# TODO: rename to infer queue
rhs = stmt.value
if rhs.op == "call":
if self._check_if_dpnp_empty_call(rhs.func, parent_block):
Expand Down Expand Up @@ -305,6 +318,7 @@ def run(self):

@register_pass(mutates_CFG=True, analysis_only=False)
class ParforLegalizeCFDPass(FunctionPass):
# TODO: rename to execution queue inferring Pass
_name = "parfor_Legalize_CFD_pass"

def __init__(self):
Expand Down
27 changes: 11 additions & 16 deletions numba_dpex/core/targets/dpjit_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from numba.core.codegen import JITCPUCodegen
from numba.core.compiler_lock import global_compiler_lock
from numba.core.cpu import CPUContext
from numba.core.imputils import Registry, RegistryLoader
from numba.core.imputils import Registry
from numba.core.target_extension import CPU, target_registry


Expand All @@ -34,31 +34,26 @@ def __init__(self, typingctx, target=DPEX_TARGET_NAME):

@global_compiler_lock
def init(self):
self.is32bit = utils.MACHINE_BITS == 32
self._internal_codegen = JITCPUCodegen("numba.exec")
self.lower_extensions = {}
super().init()

# TODO: initialize nrt once switched to nrt from drt. Most likely we
# call it somewhere. Double check.
# https://github.com/IntelPython/numba-dpex/issues/1175
# Initialize NRT runtime
# rtsys.initialize(self) # noqa: E800
self.refresh()

@cached_property
def dpexrt(self):
from numba_dpex.core.runtime.context import DpexRTContext

return DpexRTContext(self)

def refresh(self):
registry = dpex_function_registry
try:
loader = self._registries[registry]
except KeyError:
loader = RegistryLoader(registry)
self._registries[registry] = loader
self.install_registry(registry)
# Also refresh typing context, since @overload declarations can
# affect it.
self.typing_context.refresh()
super().refresh()
def load_additional_registries(self):
"""
Load dpjit-specific registries.
"""
self.install_registry(dpex_function_registry)

# loading CPU specific registries
super().load_additional_registries()
10 changes: 4 additions & 6 deletions numba_dpex/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
compile_func_template,
)
from numba_dpex.core.pipelines.dpjit_compiler import get_compiler
from numba_dpex.core.targets.dpjit_target import DPEX_TARGET_NAME

from .config import USE_MLIR

Expand Down Expand Up @@ -140,7 +141,7 @@ def _wrapped(pyfunc):


def dpjit(*args, **kws):
if "nopython" in kws:
if "nopython" in kws and kws["nopython"] is not True:
warnings.warn(
"nopython is set for dpjit and is ignored", RuntimeWarning
)
Expand All @@ -161,14 +162,11 @@ def dpjit(*args, **kws):
kws.update({"parallel": True})
kws.update({"pipeline_class": get_compiler(use_mlir)})

# FIXME: When trying to use dpex's target context, overloads do not work
# properly. We will turn on dpex target once the issue is fixed.

# kws.update({"_target": "dpex"}) # noqa: E800
kws.update({"_target": DPEX_TARGET_NAME})

return decorators.jit(*args, **kws)


# add it to the decorator registry, this is so e.g. @overload can look up a
# JIT function to do the compilation work.
jit_registry[target_registry["dpex"]] = dpjit
jit_registry[target_registry[DPEX_TARGET_NAME]] = dpjit

0 comments on commit 762b9a7

Please sign in to comment.