Skip to content

Commit

Permalink
Remove dtypes lowering from IET layer
Browse files Browse the repository at this point in the history
  • Loading branch information
enwask committed Jul 16, 2024
1 parent d3169d0 commit 493c1e8
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 53 deletions.
10 changes: 5 additions & 5 deletions devito/passes/iet/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from devito.ir import (Block, Call, Definition, DummyExpr, Return, EntryFunction,
FindSymbols, MapExprStmts, Transformer, make_callable)
from devito.passes import is_gpu_create
from devito.passes.iet.dtypes import lower_dtypes
from devito.passes.iet.dtypes import include_complex
from devito.passes.iet.engine import iet_pass
from devito.passes.iet.langbase import LangBB
from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer,
Expand Down Expand Up @@ -413,8 +413,8 @@ def place_casts(self, iet, **kwargs):
return iet, {}

@iet_pass
def make_langtypes(self, iet):
iet, metadata = lower_dtypes(iet, self.lang, self.compiler, self.sregistry)
def include_complex(self, iet):
iet, metadata = include_complex(iet, self.lang, self.compiler)
return iet, metadata

def process(self, graph):
Expand All @@ -423,7 +423,7 @@ def process(self, graph):
"""
self.place_definitions(graph, globs=set())
self.place_casts(graph)
self.make_langtypes(graph)
self.include_complex(graph)


class DeviceAwareDataManager(DataManager):
Expand Down Expand Up @@ -573,7 +573,7 @@ def process(self, graph):
self.place_devptr(graph)
self.place_bundling(graph, writes_input=graph.writes_input)
self.place_casts(graph)
self.make_langtypes(graph)
self.include_complex(graph)


def make_zero_init(obj):
Expand Down
54 changes: 6 additions & 48 deletions devito/passes/iet/dtypes.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,21 @@
import numpy as np
import ctypes

from devito.ir import FindSymbols, Uxreplace
from devito.ir.iet.nodes import Dereference
from devito.tools.utils import as_list
from devito.types.basic import Symbol
from devito.ir import FindSymbols

__all__ = ['lower_dtypes']
__all__ = ['include_complex']


def lower_dtypes(iet, lang, compiler, sregistry):
def include_complex(iet, lang, compiler):
"""
Lower language-specific dtypes and add headers for complex arithmetic
"""
# Include complex headers if needed (before we replace complex dtypes)
metadata = _complex_includes(iet, lang, compiler)

body_prefix = [] # Derefs to prepend to the body
body_mapper = {}
params_mapper = {}

# Lower scalar float16s to pointers and dereference them
if lang.get('half_types') is not None:
half, half_p = lang['half_types'] # dtype mappings for half float

for s in FindSymbols('scalars').visit(iet):
if s.dtype != np.float16 or s not in iet.parameters:
continue

ptr = s._rebuild(dtype=half_p, is_const=True)
val = Symbol(name=sregistry.make_name(prefix='hf'), dtype=half,
is_const=s.is_const)

params_mapper[s], body_mapper[s] = ptr, val
body_prefix.append(Dereference(val, ptr)) # val = *ptr

# Lower remaining language-specific dtypes
for s in FindSymbols('indexeds|basics|symbolics').visit(iet):
if s.dtype in lang['types'] and s not in params_mapper:
body_mapper[s] = params_mapper[s] = s._rebuild(dtype=lang['types'][s.dtype])

# Apply the dtype replacements
body = body_prefix + as_list(Uxreplace(body_mapper).visit(iet.body))
params = Uxreplace(params_mapper).visit(iet.parameters)

iet = iet._rebuild(body=body, parameters=params)
return iet, metadata


def _complex_includes(iet, lang, compiler):
"""
Add headers for complex arithmetic
Include complex arithmetic headers for the given language, if needed.
"""
# Check if there is complex numbers that always take dtype precedence
types = {f.dtype for f in FindSymbols().visit(iet)
if not issubclass(f.dtype, ctypes._Pointer)}

if not any(np.issubdtype(d, np.complexfloating) for d in types):
return {}
return iet, {}

metadata = {}
lib = (lang['header-complex'],)
Expand All @@ -75,4 +33,4 @@ def _complex_includes(iet, lang, compiler):

metadata['includes'] = lib

return metadata
return iet, metadata

0 comments on commit 493c1e8

Please sign in to comment.