Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Fix complex arguments and implement float16 lowering #2403

Open
wants to merge 18 commits into
base: complex
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions devito/data/allocators.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from devito.logger import logger
from devito.parameters import configuration
from devito.tools import dtype_to_ctype, is_integer
from devito.tools import is_integer, dtype_alloc_ctype

__all__ = ['ALLOC_ALIGNED', 'ALLOC_NUMA_LOCAL', 'ALLOC_NUMA_ANY',
'ALLOC_KNL_MCDRAM', 'ALLOC_KNL_DRAM', 'ALLOC_GUARD',
Expand Down Expand Up @@ -92,12 +92,8 @@ def initialize(cls):
return

def alloc(self, shape, dtype, padding=0):
# For complex number, allocate double the size of its real/imaginary part
alloc_dtype = dtype(0).real.__class__
c_scale = 2 if np.issubdtype(dtype, np.complexfloating) else 1

ctype, c_scale = dtype_alloc_ctype(dtype)
datasize = int(reduce(mul, shape) * c_scale)
ctype = dtype_to_ctype(alloc_dtype)

# Add padding, if any
try:
Expand Down
22 changes: 16 additions & 6 deletions devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""The Iteration/Expression Tree (IET) hierarchy."""

import abc
import ctypes
import inspect
from functools import cached_property
from collections import OrderedDict, namedtuple
Expand Down Expand Up @@ -1030,6 +1031,8 @@ class Dereference(ExprStmt, Node):
* `pointer` is a PointerArray or TempFunction, and `pointee` is an Array.
* `pointer` is an ArrayObject representing a pointer to a C struct, and
`pointee` is a field in `pointer`.
* `pointer` is a Symbol with its _C_ctype deriving from ct._Pointer, and
`pointee` is a Symbol representing the dereferenced value.
georgebisbas marked this conversation as resolved.
Show resolved Hide resolved
"""

is_Dereference = True
Expand All @@ -1048,13 +1051,20 @@ def functions(self):

@property
def expr_symbols(self):
ret = [self.pointer.indexed]
if self.pointer.is_PointerArray or self.pointer.is_TempFunction:
ret.append(self.pointee.indexed)
ret.extend(flatten(i.free_symbols for i in self.pointee.symbolic_shape[1:]))
ret.extend(self.pointer.free_symbols)
ret = []
if self.pointer.is_Symbol:
assert issubclass(self.pointer._C_ctype, ctypes._Pointer), \
"Scalar dereference must have a pointer ctype"
ret.extend([self.pointer._C_symbol, self.pointee._C_symbol])
else:
ret.append(self.pointee._C_symbol)
ret.append(self.pointer.indexed)
if self.pointer.is_PointerArray or self.pointer.is_TempFunction:
enwask marked this conversation as resolved.
Show resolved Hide resolved
ret.append(self.pointee.indexed)
ret.extend(flatten(i.free_symbols
for i in self.pointee.symbolic_shape[1:]))
ret.extend(self.pointer.free_symbols)
else:
ret.append(self.pointee._C_symbol)
return tuple(filter_ordered(ret))

@property
Expand Down
10 changes: 8 additions & 2 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
from devito.ir.support.space import Backward
from devito.symbolics import (FieldFromComposite, FieldFromPointer,
ListInitializer, ccode, uxreplace)
from devito.symbolics.extended_dtypes import NoDeclStruct
from devito.tools import (GenericVisitor, as_tuple, ctypes_to_cstr, filter_ordered,
filter_sorted, flatten, is_external_ctype,
c_restrict_void_p, sorted_priority)
from devito.types.basic import AbstractFunction, Basic
from devito.types.basic import AbstractFunction, AbstractSymbol, Basic
from devito.types import (ArrayObject, CompositeObject, Dimension, Pointer,
IndexedData, DeviceMap)

Expand Down Expand Up @@ -208,7 +209,7 @@ def _gen_struct_decl(self, obj, masked=()):
while issubclass(ctype, ctypes._Pointer):
ctype = ctype._type_

if not issubclass(ctype, ctypes.Structure):
if not issubclass(ctype, ctypes.Structure) or issubclass(ctype, NoDeclStruct):
enwask marked this conversation as resolved.
Show resolved Hide resolved
return None
except TypeError:
# E.g., `ctype` is of type `dtypes_lowering.CustomDtype`
Expand Down Expand Up @@ -453,6 +454,9 @@ def visit_Dereference(self, o):
lvalue = c.Value(cstr, '*restrict %s' % a0.name)
if a0._data_alignment:
lvalue = c.AlignedAttribute(a0._data_alignment, lvalue)
elif a1.is_Symbol:
enwask marked this conversation as resolved.
Show resolved Hide resolved
rvalue = '*%s' % a1.name
lvalue = self._gen_value(a0, 0)
else:
rvalue = '%s->%s' % (a1.name, a0._C_name)
lvalue = self._gen_value(a0, 0)
Expand Down Expand Up @@ -957,6 +961,7 @@ def default_retval(cls):
Drive the search. Accepted:
- `symbolics`: Collect all AbstractFunction objects, default
- `basics`: Collect all Basic objects
- `scalars`: Collect all AbstractSymbol objects
enwask marked this conversation as resolved.
Show resolved Hide resolved
- `dimensions`: Collect all Dimensions
- `indexeds`: Collect all Indexed objects
- `indexedbases`: Collect all IndexedBase objects
Expand All @@ -977,6 +982,7 @@ def _defines_aliases(n):
rules = {
'symbolics': lambda n: n.functions,
'basics': lambda n: [i for i in n.expr_symbols if isinstance(i, Basic)],
'scalars': lambda n: [i for i in n.expr_symbols if isinstance(i, AbstractSymbol)],
enwask marked this conversation as resolved.
Show resolved Hide resolved
'dimensions': lambda n: [i for i in n.expr_symbols if isinstance(i, Dimension)],
'indexeds': lambda n: [i for i in n.expr_symbols if i.is_Indexed],
'indexedbases': lambda n: [i for i in n.expr_symbols
Expand Down
10 changes: 10 additions & 0 deletions devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
from devito.passes import (Graph, lower_index_derivatives, generate_implicit,
generate_macros, minimize_symbols, unevaluate,
error_mapper, is_on_device)
from devito.passes.iet.langbase import LangBB
from devito.symbolics import estimate_cost, subs_op_args
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_mapper, as_tuple,
flatten, filter_sorted, frozendict, is_integer,
split, timed_pass, timed_region, contains_val)
from devito.tools.dtypes_lowering import ctypes_vector_mapper
from devito.types import (Buffer, Grid, Evaluable, host_layer, device_layer,
disk_layer)

Expand Down Expand Up @@ -264,6 +266,9 @@ def _lower(cls, expressions, **kwargs):
# expression for which a partial or complete lowering is desired
kwargs['rcompile'] = cls._rcompile_wrapper(**kwargs)

# Load language-specific types into the global dtype->ctype mapper
cls._load_dtype_mappings(**kwargs)

# [Eq] -> [LoweredEq]
expressions = cls._lower_exprs(expressions, **kwargs)

Expand All @@ -285,6 +290,11 @@ def _lower(cls, expressions, **kwargs):
def _rcompile_wrapper(cls, **kwargs0):
raise NotImplementedError

@classmethod
def _load_dtype_mappings(cls, **kwargs):
    """
    Merge the target language's type overrides into the dtype->ctype mapper.

    The language class is taken from ``cls._Target.DataManager.lang``; the
    entries stored under the ``'types'`` key of its ``mapper`` (nothing, if
    the key is absent) are merged into ``ctypes_vector_mapper``.

    Notes
    -----
    ``ctypes_vector_mapper`` is imported at module level, so this update is
    not local to the Operator being built. The keyword arguments are
    accepted but not used here.
    """
    backend: type[LangBB] = cls._Target.DataManager.lang
    overrides = backend.mapper.get('types', {})
    ctypes_vector_mapper.update(overrides)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that's a bit tricky, because this updates the global ctypes_vector_mapper, which might lead to odd behavior when building multiple operators with different languages.
Do you know where it's called and where those types are needed? I.e., can the mapper be "local" to the operator and passed there?


@classmethod
def _initialize_state(cls, **kwargs):
    """Return a new, empty state dictionary; keyword arguments are ignored."""
    state = {}
    return state
Expand Down
10 changes: 5 additions & 5 deletions devito/passes/iet/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from devito.ir import (Block, Call, Definition, DummyExpr, Return, EntryFunction,
FindSymbols, MapExprStmts, Transformer, make_callable)
from devito.passes import is_gpu_create
from devito.passes.iet.dtypes import lower_complex
from devito.passes.iet.dtypes import lower_dtypes
from devito.passes.iet.engine import iet_pass
from devito.passes.iet.langbase import LangBB
from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer,
Expand Down Expand Up @@ -413,8 +413,8 @@ def place_casts(self, iet, **kwargs):
return iet, {}

@iet_pass
def make_langtypes(self, iet):
iet, metadata = lower_complex(iet, self.lang, self.compiler)
def lower_dtypes(self, iet):
iet, metadata = lower_dtypes(iet, self.lang, self.compiler, self.sregistry)
return iet, metadata

def process(self, graph):
Expand All @@ -423,7 +423,7 @@ def process(self, graph):
"""
self.place_definitions(graph, globs=set())
self.place_casts(graph)
self.make_langtypes(graph)
self.lower_dtypes(graph)


class DeviceAwareDataManager(DataManager):
Expand Down Expand Up @@ -573,7 +573,7 @@ def process(self, graph):
self.place_devptr(graph)
self.place_bundling(graph, writes_input=graph.writes_input)
self.place_casts(graph)
self.make_langtypes(graph)
self.lower_dtypes(graph)


def make_zero_init(obj):
Expand Down
68 changes: 45 additions & 23 deletions devito/passes/iet/dtypes.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,54 @@
import numpy as np
import ctypes

from devito.ir import FindSymbols, Uxreplace
from devito.ir import FindSymbols
from devito.ir.iet.nodes import Dereference
from devito.ir.iet.visitors import Uxreplace
from devito.symbolics.extended_dtypes import Float16P
from devito.tools.utils import as_list
from devito.types.basic import Symbol

__all__ = ['lower_complex']
__all__ = ['lower_dtypes']


def lower_complex(iet, lang, compiler):
def lower_dtypes(iet, lang, compiler, sregistry):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how can this not be an @iet_pass and still work?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's called from an @iet_pass in definitions here

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah OK. I may have told Mathias already, but IMHO that thing goes straight into operator/operator.py::Operator::_lower_iet

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah okay, I can make this change

"""
Add headers for complex arithmetic
Lowers float16 scalar types to pointers since we can't directly pass their
value. Also includes headers for complex arithmetic if needed.
"""

iet, metadata = _complex_includes(iet, lang, compiler)

# Lower float16 parameters to pointers and dereference
body_prefix = []
enwask marked this conversation as resolved.
Show resolved Hide resolved
body_mapper = {}
enwask marked this conversation as resolved.
Show resolved Hide resolved
params_mapper = {}

# Lower scalar float16s to pointers and dereference them
for s in FindSymbols('scalars').visit(iet):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you just do visit(iet.parameters) directly instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like not: since the actual types I'm mapping (e.g. Constant) aren't nodes, they are only caught by FindSymbols when they are referenced within some other expression, so we need to visit the body. That said, I can probably make this marginally more efficient by building a set of parameters beforehand and checking membership that way.

if not np.issubdtype(s.dtype, np.float16) or s not in iet.parameters:
continue

# Replace the parameter with a pointer; replace occurrences in the IET
# body with a dereference (using the original symbol's dtype)
ptr = s._rebuild(dtype=Float16P, is_const=True)
val = Symbol(name=sregistry.make_name(prefix='hf'), dtype=s.dtype,
is_const=s.is_const)

params_mapper[s], body_mapper[s] = ptr, val
body_prefix.append(Dereference(val, ptr)) # val = *ptr

# Apply the replacements
body = body_prefix + as_list(Uxreplace(body_mapper).visit(iet.body))
enwask marked this conversation as resolved.
Show resolved Hide resolved
params = Uxreplace(params_mapper).visit(iet.parameters)

iet = iet._rebuild(body=body, parameters=params)
return iet, metadata


def _complex_includes(iet, lang, compiler):
"""
Include complex arithmetic headers for the given language, if needed.
"""
# Check if there are complex numbers that always take dtype precedence
types = {f.dtype for f in FindSymbols().visit(iet)
Expand All @@ -17,9 +57,9 @@ def lower_complex(iet, lang, compiler):
if not any(np.issubdtype(d, np.complexfloating) for d in types):
return iet, {}

metadata = {}
lib = (lang['header-complex'],)

metadata = {}
if lang.get('complex-namespace') is not None:
metadata['namespaces'] = lang['complex-namespace']

Expand All @@ -31,24 +71,6 @@ def lower_complex(iet, lang, compiler):
ff.write(str(lang['def-complex']))
lib += (str(hfile),)

iet = _complex_dtypes(iet, lang)
metadata['includes'] = lib

return iet, metadata


def _complex_dtypes(iet, lang):
"""
Lower dtypes to language specific types
"""
mapper = {}

for s in FindSymbols('indexeds|basics|symbolics').visit(iet):
if s.dtype in lang['types']:
mapper[s] = s._rebuild(dtype=lang['types'][s.dtype])

body = Uxreplace(mapper).visit(iet.body)
params = Uxreplace(mapper).visit(iet.parameters)
iet = iet._rebuild(body=body, parameters=params)

return iet
29 changes: 12 additions & 17 deletions devito/passes/iet/languages/C.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,21 @@
import ctypes as ct
import numpy as np

from devito.ir import Call
from devito.passes.iet.definitions import DataManager
from devito.passes.iet.orchestration import Orchestrator
from devito.passes.iet.langbase import LangBB
from devito.tools.dtypes_lowering import ctypes_vector_mapper
from devito.symbolics.extended_dtypes import (Float16P, c_complex, c_double_complex,
c_half, c_half_p)


__all__ = ['CBB', 'CDataManager', 'COrchestrator']
__all__ = ['CBB', 'CDataManager', 'COrchestrator', 'c_float16', 'c_float16_p']


class CCFloat(np.complex64):
pass
c99_complex = type('_Complex float', (c_complex,), {})
c99_double_complex = type('_Complex double', (c_double_complex,), {})


class CCDouble(np.complex128):
pass


c_complex = type('_Complex float', (ct.c_double,), {})
c_double_complex = type('_Complex double', (ct.c_longdouble,), {})

ctypes_vector_mapper[CCFloat] = c_complex
ctypes_vector_mapper[CCDouble] = c_double_complex
c_float16 = type('_Float16', (c_half,), {})
c_float16_p = type('_Float16 *', (c_half_p,), {'_type_': c_float16})


class CBB(LangBB):
Expand All @@ -40,9 +32,12 @@ class CBB(LangBB):
Call('free', (i,)),
'alloc-global-symbol': lambda i, j, k:
Call('memcpy', (i, j, k)),
# Complex
# Complex and float16
'header-complex': 'complex.h',
'types': {np.complex128: CCDouble, np.complex64: CCFloat},
'types': {np.complex128: c99_double_complex,
np.complex64: c99_complex,
np.float16: c_float16,
Float16P: c_float16_p}
}


Expand Down
27 changes: 9 additions & 18 deletions devito/passes/iet/languages/CXX.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import ctypes as ct
import numpy as np

from devito.ir import Call, UsingNamespace
from devito.passes.iet.langbase import LangBB
from devito.tools.dtypes_lowering import ctypes_vector_mapper
from devito.passes.iet.languages.C import c_float16, c_float16_p
from devito.symbolics.extended_dtypes import Float16P, c_complex, c_double_complex

__all__ = ['CXXBB']

Expand Down Expand Up @@ -45,20 +45,8 @@
"""


class CXXCFloat(np.complex64):
pass


class CXXCDouble(np.complex128):
pass


cxx_complex = type('std::complex<float>', (ct.c_double,), {})
cxx_double_complex = type('std::complex<double>', (ct.c_longdouble,), {})


ctypes_vector_mapper[CXXCFloat] = cxx_complex
ctypes_vector_mapper[CXXCDouble] = cxx_double_complex
cxx_complex = type('std::complex<float>', (c_complex,), {})
cxx_double_complex = type('std::complex<double>', (c_double_complex,), {})


class CXXBB(LangBB):
Expand All @@ -75,9 +63,12 @@ class CXXBB(LangBB):
Call('free', (i,)),
'alloc-global-symbol': lambda i, j, k:
Call('memcpy', (i, j, k)),
# Complex
# Complex and float16
'header-complex': 'complex',
'complex-namespace': [UsingNamespace('std::complex_literals')],
'def-complex': std_arith,
'types': {np.complex128: CXXCDouble, np.complex64: CXXCFloat},
"types": {np.complex128: cxx_double_complex,
np.complex64: cxx_complex,
np.float16: c_float16,
Float16P: c_float16_p}
}
Loading