Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Fix complex arguments and implement float16 lowering #2403

Open
wants to merge 18 commits into
base: complex
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions devito/data/allocators.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from devito.logger import logger
from devito.parameters import configuration
from devito.tools import dtype_to_ctype, is_integer
from devito.tools import is_integer, dtype_alloc_ctype

__all__ = ['ALLOC_ALIGNED', 'ALLOC_NUMA_LOCAL', 'ALLOC_NUMA_ANY',
'ALLOC_KNL_MCDRAM', 'ALLOC_KNL_DRAM', 'ALLOC_GUARD',
Expand Down Expand Up @@ -92,12 +92,8 @@ def initialize(cls):
return

def alloc(self, shape, dtype, padding=0):
# For complex number, allocate double the size of its real/imaginary part
alloc_dtype = dtype(0).real.__class__
c_scale = 2 if np.issubdtype(dtype, np.complexfloating) else 1

ctype, c_scale = dtype_alloc_ctype(dtype)
datasize = int(reduce(mul, shape) * c_scale)
ctype = dtype_to_ctype(alloc_dtype)

# Add padding, if any
try:
Expand Down
18 changes: 13 additions & 5 deletions devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""The Iteration/Expression Tree (IET) hierarchy."""

import abc
import ctypes
import inspect
from functools import cached_property
from collections import OrderedDict, namedtuple
Expand Down Expand Up @@ -1030,6 +1031,8 @@ class Dereference(ExprStmt, Node):
* `pointer` is a PointerArray or TempFunction, and `pointee` is an Array.
* `pointer` is an ArrayObject representing a pointer to a C struct, and
`pointee` is a field in `pointer`.
* `pointer` is a Symbol with its _C_ctype deriving from ct._Pointer, and
`pointee` is a Symbol representing the dereferenced value.
georgebisbas marked this conversation as resolved.
Show resolved Hide resolved
"""

is_Dereference = True
Expand All @@ -1048,13 +1051,18 @@ def functions(self):

@property
def expr_symbols(self):
ret = [self.pointer.indexed]
if self.pointer.is_PointerArray or self.pointer.is_TempFunction:
ret.append(self.pointee.indexed)
ret.extend(flatten(i.free_symbols for i in self.pointee.symbolic_shape[1:]))
ret = []
if self.pointer.is_Symbol:
assert issubclass(self.pointer._C_ctype, ctypes._Pointer), \
"Scalar dereference must have a pointer ctype"
ret.extend([self.pointer._C_symbol, self.pointee._C_symbol])
elif self.pointer.is_PointerArray or self.pointer.is_TempFunction:
ret.extend([self.pointer.indexed, self.pointee.indexed])
ret.extend(flatten(i.free_symbols
for i in self.pointee.symbolic_shape[1:]))
ret.extend(self.pointer.free_symbols)
else:
ret.append(self.pointee._C_symbol)
ret.extend([self.pointer.indexed, self.pointee._C_symbol])
return tuple(filter_ordered(ret))

@property
Expand Down
10 changes: 7 additions & 3 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
from devito.ir.support.space import Backward
from devito.symbolics import (FieldFromComposite, FieldFromPointer,
ListInitializer, ccode, uxreplace)
from devito.symbolics.extended_dtypes import NoDeclStruct
from devito.tools import (GenericVisitor, as_tuple, ctypes_to_cstr, filter_ordered,
filter_sorted, flatten, is_external_ctype,
c_restrict_void_p, sorted_priority)
from devito.types.basic import AbstractFunction, Basic
from devito.types.basic import AbstractFunction, AbstractSymbol, Basic
from devito.types import (ArrayObject, CompositeObject, Dimension, Pointer,
IndexedData, DeviceMap)

Expand Down Expand Up @@ -208,7 +209,7 @@ def _gen_struct_decl(self, obj, masked=()):
while issubclass(ctype, ctypes._Pointer):
ctype = ctype._type_

if not issubclass(ctype, ctypes.Structure):
if not issubclass(ctype, ctypes.Structure) or issubclass(ctype, NoDeclStruct):
enwask marked this conversation as resolved.
Show resolved Hide resolved
return None
except TypeError:
# E.g., `ctype` is of type `dtypes_lowering.CustomDtype`
Expand Down Expand Up @@ -454,7 +455,7 @@ def visit_Dereference(self, o):
if a0._data_alignment:
lvalue = c.AlignedAttribute(a0._data_alignment, lvalue)
else:
rvalue = '%s->%s' % (a1.name, a0._C_name)
rvalue = '*%s' % a1.name if a1.is_Symbol else '%s->%s' % (a1.name, a0._C_name)
enwask marked this conversation as resolved.
Show resolved Hide resolved
lvalue = self._gen_value(a0, 0)
return c.Initializer(lvalue, rvalue)

Expand Down Expand Up @@ -957,6 +958,7 @@ def default_retval(cls):
Drive the search. Accepted:
- `symbolics`: Collect all AbstractFunction objects, default
- `basics`: Collect all Basic objects
- `abstractsymbols`: Collect all AbstractSymbol objects
- `dimensions`: Collect all Dimensions
- `indexeds`: Collect all Indexed objects
- `indexedbases`: Collect all IndexedBase objects
Expand All @@ -977,6 +979,8 @@ def _defines_aliases(n):
rules = {
'symbolics': lambda n: n.functions,
'basics': lambda n: [i for i in n.expr_symbols if isinstance(i, Basic)],
'abstractsymbols': lambda n: [i for i in n.expr_symbols
if isinstance(i, AbstractSymbol)],
'dimensions': lambda n: [i for i in n.expr_symbols if isinstance(i, Dimension)],
'indexeds': lambda n: [i for i in n.expr_symbols if i.is_Indexed],
'indexedbases': lambda n: [i for i in n.expr_symbols
Expand Down
10 changes: 10 additions & 0 deletions devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
from devito.passes import (Graph, lower_index_derivatives, generate_implicit,
generate_macros, minimize_symbols, unevaluate,
error_mapper, is_on_device)
from devito.passes.iet.langbase import LangBB
from devito.symbolics import estimate_cost, subs_op_args
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_mapper, as_tuple,
flatten, filter_sorted, frozendict, is_integer,
split, timed_pass, timed_region, contains_val)
from devito.tools.dtypes_lowering import ctypes_vector_mapper
from devito.types import (Buffer, Grid, Evaluable, host_layer, device_layer,
disk_layer)

Expand Down Expand Up @@ -264,6 +266,9 @@ def _lower(cls, expressions, **kwargs):
# expression for which a partial or complete lowering is desired
kwargs['rcompile'] = cls._rcompile_wrapper(**kwargs)

# Load language-specific types into the global dtype->ctype mapper
cls._load_dtype_mappings(**kwargs)

# [Eq] -> [LoweredEq]
expressions = cls._lower_exprs(expressions, **kwargs)

Expand All @@ -285,6 +290,11 @@ def _lower(cls, expressions, **kwargs):
def _rcompile_wrapper(cls, **kwargs0):
raise NotImplementedError

@classmethod
def _load_dtype_mappings(cls, **kwargs):
lang: type[LangBB] = cls._Target.DataManager.lang
ctypes_vector_mapper.update(lang.mapper.get('types', {}))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that's a bit tricky because this updates a global ctypes_vector_mapper which might lead to odd behavior building multiple operators with different languages.
Do you know where it's called and needs those types ? I.e can the mapper be "local" to the operator and passed there?


@classmethod
def _initialize_state(cls, **kwargs):
return {}
Expand Down
10 changes: 5 additions & 5 deletions devito/passes/iet/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from devito.ir import (Block, Call, Definition, DummyExpr, Return, EntryFunction,
FindSymbols, MapExprStmts, Transformer, make_callable)
from devito.passes import is_gpu_create
from devito.passes.iet.dtypes import lower_complex
from devito.passes.iet.dtypes import lower_dtypes
from devito.passes.iet.engine import iet_pass
from devito.passes.iet.langbase import LangBB
from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer,
Expand Down Expand Up @@ -413,8 +413,8 @@ def place_casts(self, iet, **kwargs):
return iet, {}

@iet_pass
def make_langtypes(self, iet):
iet, metadata = lower_complex(iet, self.lang, self.compiler)
def lower_dtypes(self, iet):
iet, metadata = lower_dtypes(iet, self.lang, self.compiler, self.sregistry)
return iet, metadata

def process(self, graph):
Expand All @@ -423,7 +423,7 @@ def process(self, graph):
"""
self.place_definitions(graph, globs=set())
self.place_casts(graph)
self.make_langtypes(graph)
self.lower_dtypes(graph)


class DeviceAwareDataManager(DataManager):
Expand Down Expand Up @@ -573,7 +573,7 @@ def process(self, graph):
self.place_devptr(graph)
self.place_bundling(graph, writes_input=graph.writes_input)
self.place_casts(graph)
self.make_langtypes(graph)
self.lower_dtypes(graph)


def make_zero_init(obj):
Expand Down
76 changes: 51 additions & 25 deletions devito/passes/iet/dtypes.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,69 @@
import numpy as np
import ctypes
import numpy as np

from devito.arch.compiler import Compiler
from devito.ir import Callable, Dereference, FindSymbols, Node, SymbolRegistry, Uxreplace
from devito.passes.iet.langbase import LangBB
from devito.symbolics.extended_dtypes import Float16P
from devito.tools import as_list
from devito.types.basic import AbstractSymbol, Basic, Symbol

__all__ = ['lower_dtypes']


def lower_dtypes(iet: Callable, lang: type[LangBB], compiler: Compiler,
sregistry: SymbolRegistry) -> tuple[Callable, dict]:
"""
Lowers float16 scalar types to pointers since we can't directly pass their
value. Also includes headers for complex arithmetic if needed.
"""

iet, metadata = _complex_includes(iet, lang, compiler)

# Lower float16 parameters to pointers and dereference
prefix: list[Node] = []
params_mapper: dict[AbstractSymbol, AbstractSymbol] = {}
body_mapper: dict[AbstractSymbol, Symbol] = {}

params_set = set(iet.parameters)
s: AbstractSymbol
for s in FindSymbols('abstractsymbols').visit(iet):
if s.dtype != np.float16 or s not in params_set:
continue

# Replace the parameter with a pointer; replace occurences in the IET
# body with dereferenced symbol (using the original symbol's dtype)
ptr: AbstractSymbol = s._rebuild(dtype=Float16P, is_const=True)
val = Symbol(name=sregistry.make_name(prefix='hf'), dtype=s.dtype,
is_const=s.is_const)

from devito.ir import FindSymbols, Uxreplace
params_mapper[s], body_mapper[s] = ptr, val
prefix.append(Dereference(val, ptr)) # val = *ptr

__all__ = ['lower_complex']
# Apply the replacements
prefix.extend(as_list(Uxreplace(body_mapper).visit(iet.body)))
params: tuple[Basic] = Uxreplace(params_mapper).visit(iet.parameters)

iet = iet._rebuild(body=prefix, parameters=params)
return iet, metadata


def lower_complex(iet, lang, compiler):
def _complex_includes(iet: Callable, lang: type[LangBB],
compiler: Compiler) -> tuple[Callable, dict]:
"""
Add headers for complex arithmetic
Includes complex arithmetic headers for the given language, if needed.
"""
# Check if there is complex numbers that always take dtype precedence

# Check if there are complex numbers that always take dtype precedence
types = {f.dtype for f in FindSymbols().visit(iet)
if not issubclass(f.dtype, ctypes._Pointer)}

if not any(np.issubdtype(d, np.complexfloating) for d in types):
return iet, {}

metadata = {}
lib = (lang['header-complex'],)

metadata = {}
if lang.get('complex-namespace') is not None:
metadata['namespaces'] = lang['complex-namespace']

Expand All @@ -31,24 +75,6 @@ def lower_complex(iet, lang, compiler):
ff.write(str(lang['def-complex']))
lib += (str(hfile),)

iet = _complex_dtypes(iet, lang)
metadata['includes'] = lib

return iet, metadata


def _complex_dtypes(iet, lang):
"""
Lower dtypes to language specific types
"""
mapper = {}

for s in FindSymbols('indexeds|basics|symbolics').visit(iet):
if s.dtype in lang['types']:
mapper[s] = s._rebuild(dtype=lang['types'][s.dtype])

body = Uxreplace(mapper).visit(iet.body)
params = Uxreplace(mapper).visit(iet.parameters)
iet = iet._rebuild(body=body, parameters=params)

return iet
29 changes: 12 additions & 17 deletions devito/passes/iet/languages/C.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,21 @@
import ctypes as ct
import numpy as np

from devito.ir import Call
from devito.passes.iet.definitions import DataManager
from devito.passes.iet.orchestration import Orchestrator
from devito.passes.iet.langbase import LangBB
from devito.tools.dtypes_lowering import ctypes_vector_mapper
from devito.symbolics.extended_dtypes import (Float16P, c_complex, c_double_complex,
c_half, c_half_p)


__all__ = ['CBB', 'CDataManager', 'COrchestrator']
__all__ = ['CBB', 'CDataManager', 'COrchestrator', 'c_float16', 'c_float16_p']


class CCFloat(np.complex64):
pass
c99_complex = type('_Complex float', (c_complex,), {})
c99_double_complex = type('_Complex double', (c_double_complex,), {})


class CCDouble(np.complex128):
pass


c_complex = type('_Complex float', (ct.c_double,), {})
c_double_complex = type('_Complex double', (ct.c_longdouble,), {})

ctypes_vector_mapper[CCFloat] = c_complex
ctypes_vector_mapper[CCDouble] = c_double_complex
c_float16 = type('_Float16', (c_half,), {})
c_float16_p = type('_Float16 *', (c_half_p,), {'_type_': c_float16})


class CBB(LangBB):
Expand All @@ -40,9 +32,12 @@ class CBB(LangBB):
Call('free', (i,)),
'alloc-global-symbol': lambda i, j, k:
Call('memcpy', (i, j, k)),
# Complex
# Complex and float16
'header-complex': 'complex.h',
'types': {np.complex128: CCDouble, np.complex64: CCFloat},
'types': {np.complex128: c99_double_complex,
np.complex64: c99_complex,
np.float16: c_float16,
Float16P: c_float16_p}
}


Expand Down
27 changes: 9 additions & 18 deletions devito/passes/iet/languages/CXX.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import ctypes as ct
import numpy as np

from devito.ir import Call, UsingNamespace
from devito.passes.iet.langbase import LangBB
from devito.tools.dtypes_lowering import ctypes_vector_mapper
from devito.passes.iet.languages.C import c_float16, c_float16_p
from devito.symbolics.extended_dtypes import Float16P, c_complex, c_double_complex

__all__ = ['CXXBB']

Expand Down Expand Up @@ -45,20 +45,8 @@
"""


class CXXCFloat(np.complex64):
pass


class CXXCDouble(np.complex128):
pass


cxx_complex = type('std::complex<float>', (ct.c_double,), {})
cxx_double_complex = type('std::complex<double>', (ct.c_longdouble,), {})


ctypes_vector_mapper[CXXCFloat] = cxx_complex
ctypes_vector_mapper[CXXCDouble] = cxx_double_complex
cxx_complex = type('std::complex<float>', (c_complex,), {})
cxx_double_complex = type('std::complex<double>', (c_double_complex,), {})


class CXXBB(LangBB):
Expand All @@ -75,9 +63,12 @@ class CXXBB(LangBB):
Call('free', (i,)),
'alloc-global-symbol': lambda i, j, k:
Call('memcpy', (i, j, k)),
# Complex
# Complex and float16
'header-complex': 'complex',
'complex-namespace': [UsingNamespace('std::complex_literals')],
'def-complex': std_arith,
'types': {np.complex128: CXXCDouble, np.complex64: CXXCFloat},
"types": {np.complex128: cxx_double_complex,
np.complex64: cxx_complex,
np.float16: c_float16,
Float16P: c_float16_p}
}
Loading