Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Fix complex arguments and implement float16 lowering #2403

Open
wants to merge 18 commits into
base: complex
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""The Iteration/Expression Tree (IET) hierarchy."""

import abc
import ctypes
import inspect
from functools import cached_property
from collections import OrderedDict, namedtuple
Expand Down Expand Up @@ -1030,6 +1031,8 @@ class Dereference(ExprStmt, Node):
* `pointer` is a PointerArray or TempFunction, and `pointee` is an Array.
* `pointer` is an ArrayObject representing a pointer to a C struct, and
`pointee` is a field in `pointer`.
* `pointer` is a Symbol with its _C_ctype deriving from ct._Pointer, and
`pointee` is a Symbol representing the dereferenced value.
georgebisbas marked this conversation as resolved.
Show resolved Hide resolved
"""

is_Dereference = True
Expand All @@ -1048,13 +1051,21 @@ def functions(self):

@property
def expr_symbols(self):
ret = [self.pointer.indexed]
if self.pointer.is_PointerArray or self.pointer.is_TempFunction:
ret.append(self.pointee.indexed)
ret.extend(flatten(i.free_symbols for i in self.pointee.symbolic_shape[1:]))
ret.extend(self.pointer.free_symbols)
else:
ret = []
if self.pointer.is_Symbol:
assert issubclass(self.pointer._C_ctype, ctypes._Pointer), \
"Scalar dereference must have a pointer ctype"
ret.append(self.pointer._C_symbol)
enwask marked this conversation as resolved.
Show resolved Hide resolved
ret.append(self.pointee._C_symbol)
else:
ret.append(self.pointer.indexed)
if self.pointer.is_PointerArray or self.pointer.is_TempFunction:
enwask marked this conversation as resolved.
Show resolved Hide resolved
ret.append(self.pointee.indexed)
ret.extend(flatten(i.free_symbols
for i in self.pointee.symbolic_shape[1:]))
ret.extend(self.pointer.free_symbols)
else:
ret.append(self.pointee._C_symbol)
return tuple(filter_ordered(ret))

@property
Expand Down
10 changes: 8 additions & 2 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
from devito.ir.support.space import Backward
from devito.symbolics import (FieldFromComposite, FieldFromPointer,
ListInitializer, ccode, uxreplace)
from devito.symbolics.extended_dtypes import NoDeclStruct
from devito.tools import (GenericVisitor, as_tuple, ctypes_to_cstr, filter_ordered,
filter_sorted, flatten, is_external_ctype,
c_restrict_void_p, sorted_priority)
from devito.types.basic import AbstractFunction, Basic
from devito.types.basic import AbstractFunction, AbstractSymbol, Basic
from devito.types import (ArrayObject, CompositeObject, Dimension, Pointer,
IndexedData, DeviceMap)

Expand Down Expand Up @@ -208,7 +209,7 @@ def _gen_struct_decl(self, obj, masked=()):
while issubclass(ctype, ctypes._Pointer):
ctype = ctype._type_

if not issubclass(ctype, ctypes.Structure):
if not issubclass(ctype, ctypes.Structure) or issubclass(ctype, NoDeclStruct):
enwask marked this conversation as resolved.
Show resolved Hide resolved
return None
except TypeError:
# E.g., `ctype` is of type `dtypes_lowering.CustomDtype`
Expand Down Expand Up @@ -453,6 +454,9 @@ def visit_Dereference(self, o):
lvalue = c.Value(cstr, '*restrict %s' % a0.name)
if a0._data_alignment:
lvalue = c.AlignedAttribute(a0._data_alignment, lvalue)
elif a1.is_Symbol:
enwask marked this conversation as resolved.
Show resolved Hide resolved
rvalue = '*%s' % a1.name
lvalue = self._gen_value(a0, 0)
else:
rvalue = '%s->%s' % (a1.name, a0._C_name)
lvalue = self._gen_value(a0, 0)
Expand Down Expand Up @@ -957,6 +961,7 @@ def default_retval(cls):
Drive the search. Accepted:
- `symbolics`: Collect all AbstractFunction objects, default
- `basics`: Collect all Basic objects
- `scalars`: Collect all AbstractSymbol objects
enwask marked this conversation as resolved.
Show resolved Hide resolved
- `dimensions`: Collect all Dimensions
- `indexeds`: Collect all Indexed objects
- `indexedbases`: Collect all IndexedBase objects
Expand All @@ -977,6 +982,7 @@ def _defines_aliases(n):
rules = {
'symbolics': lambda n: n.functions,
'basics': lambda n: [i for i in n.expr_symbols if isinstance(i, Basic)],
'scalars': lambda n: [i for i in n.expr_symbols if isinstance(i, AbstractSymbol)],
enwask marked this conversation as resolved.
Show resolved Hide resolved
'dimensions': lambda n: [i for i in n.expr_symbols if isinstance(i, Dimension)],
'indexeds': lambda n: [i for i in n.expr_symbols if i.is_Indexed],
'indexedbases': lambda n: [i for i in n.expr_symbols
Expand Down
3 changes: 2 additions & 1 deletion devito/passes/iet/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from devito.ir import (Block, Call, Definition, DummyExpr, Return, EntryFunction,
FindSymbols, MapExprStmts, Transformer, make_callable)
from devito.passes import is_gpu_create
from devito.passes.iet.dtypes import lower_complex
from devito.passes.iet.dtypes import lower_complex, lower_scalar_half
from devito.passes.iet.engine import iet_pass
from devito.passes.iet.langbase import LangBB
from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer,
Expand Down Expand Up @@ -414,6 +414,7 @@ def place_casts(self, iet, **kwargs):

@iet_pass
def make_langtypes(self, iet):
iet, _ = lower_scalar_half(iet, self.lang, self.sregistry)
enwask marked this conversation as resolved.
Show resolved Hide resolved
iet, metadata = lower_complex(iet, self.lang, self.compiler)
return iet, metadata

Expand Down
70 changes: 52 additions & 18 deletions devito/passes/iet/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,44 @@
import ctypes

from devito.ir import FindSymbols, Uxreplace
from devito.ir.iet.nodes import Dereference
from devito.tools.utils import as_tuple
from devito.types.basic import Symbol

__all__ = ['lower_complex']
__all__ = ['lower_scalar_half', 'lower_complex']


def lower_scalar_half(iet, lang, sregistry):
enwask marked this conversation as resolved.
Show resolved Hide resolved
"""
Lower half float scalars to pointers (special case, since we can't
pass them directly for lack of a ctypes equivalent)
"""
if lang.get('half_types') is None:
return iet, {}

# dtype mappings for float16
half, half_p = lang['half_types']

body = [] # derefs to prepend to the body
body_mapper = {}
enwask marked this conversation as resolved.
Show resolved Hide resolved
params_mapper = {}

for s in FindSymbols('scalars').visit(iet):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you just do visit(iet.parameters) directly instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like not, since the actual types I'm mapping (e.g. Constant) aren't nodes they are only caught by FindSymbols if it's as a reference within some other expression, so we need to visit the body. That said I can probably make this marginally more efficient by making a set of parameters beforehand and checking membership that way

if s.dtype != np.float16 or s not in iet.parameters:
continue

ptr = s._rebuild(dtype=half_p)
val = Symbol(name=sregistry.make_name(prefix='hf'), dtype=half, is_const=True)

params_mapper[s] = ptr
body_mapper[s] = val
body.append(Dereference(val, ptr)) # val = *ptr

body.extend(as_tuple(Uxreplace(body_mapper).visit(iet.body)))
enwask marked this conversation as resolved.
Show resolved Hide resolved
params = Uxreplace(params_mapper).visit(iet.parameters)

iet = iet._rebuild(body=body, parameters=params)
return iet, {}


def lower_complex(iet, lang, compiler):
Expand All @@ -14,30 +50,28 @@ def lower_complex(iet, lang, compiler):
types = {f.dtype for f in FindSymbols().visit(iet)
if not issubclass(f.dtype, ctypes._Pointer)}

if not any(np.issubdtype(d, np.complexfloating) for d in types):
return iet, {}

lib = (lang['header-complex'],)

metadata = {}
if lang.get('complex-namespace') is not None:
metadata['namespaces'] = lang['complex-namespace']
if any(np.issubdtype(d, np.complexfloating) for d in types):
lib = (lang['header-complex'],)

if lang.get('complex-namespace') is not None:
metadata['namespaces'] = lang['complex-namespace']

# Some languges such as c++11 need some extra arithmetic definitions
if lang.get('def-complex'):
dest = compiler.get_jit_dir()
hfile = dest.joinpath('complex_arith.h')
with open(str(hfile), 'w') as ff:
ff.write(str(lang['def-complex']))
lib += (str(hfile),)
# Some languges such as c++11 need some extra arithmetic definitions
if lang.get('def-complex'):
dest = compiler.get_jit_dir()
hfile = dest.joinpath('complex_arith.h')
with open(str(hfile), 'w') as ff:
ff.write(str(lang['def-complex']))
lib += (str(hfile),)

iet = _complex_dtypes(iet, lang)
metadata['includes'] = lib
metadata['includes'] = lib

iet = _lower_dtypes(iet, lang)
return iet, metadata


def _complex_dtypes(iet, lang):
def _lower_dtypes(iet, lang):
"""
Lower dtypes to language specific types
"""
Expand Down
26 changes: 19 additions & 7 deletions devito/passes/iet/languages/C.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import ctypes as ct
import numpy as np

from devito.ir import Call
from devito.passes.iet.definitions import DataManager
from devito.passes.iet.orchestration import Orchestrator
from devito.passes.iet.langbase import LangBB
from devito.symbolics.extended_dtypes import (c_complex, c_double_complex,
c_float16, c_float16_p)
from devito.tools.dtypes_lowering import ctypes_vector_mapper


Expand All @@ -19,11 +20,21 @@ class CCDouble(np.complex128):
pass


c_complex = type('_Complex float', (ct.c_double,), {})
c_double_complex = type('_Complex double', (ct.c_longdouble,), {})
class CHalf(np.float16):
pass


class CHalfP(np.float16):
pass


c99_complex = type('_Complex float', (c_complex,), {})
c99_double_complex = type('_Complex double', (c_double_complex,), {})

ctypes_vector_mapper[CCFloat] = c_complex
ctypes_vector_mapper[CCDouble] = c_double_complex
ctypes_vector_mapper[CCFloat] = c99_complex
ctypes_vector_mapper[CCDouble] = c99_double_complex
ctypes_vector_mapper[CHalf] = c_float16
ctypes_vector_mapper[CHalfP] = c_float16_p


class CBB(LangBB):
Expand All @@ -40,9 +51,10 @@ class CBB(LangBB):
Call('free', (i,)),
'alloc-global-symbol': lambda i, j, k:
Call('memcpy', (i, j, k)),
# Complex
# Complex and float16
'header-complex': 'complex.h',
'types': {np.complex128: CCDouble, np.complex64: CCFloat},
'types': {np.complex128: CCDouble, np.complex64: CCFloat, np.float16: CHalf},
'half_types': (CHalf, CHalfP),
}


Expand Down
22 changes: 17 additions & 5 deletions devito/passes/iet/languages/CXX.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import ctypes as ct
import numpy as np

from devito.ir import Call, UsingNamespace
from devito.passes.iet.langbase import LangBB
from devito.symbolics.extended_dtypes import (c_complex, c_double_complex,
c_float16, c_float16_p)
from devito.tools.dtypes_lowering import ctypes_vector_mapper

__all__ = ['CXXBB']
Expand Down Expand Up @@ -53,12 +54,21 @@ class CXXCDouble(np.complex128):
pass


cxx_complex = type('std::complex<float>', (ct.c_double,), {})
cxx_double_complex = type('std::complex<double>', (ct.c_longdouble,), {})
class CXXHalf(np.float16):
pass


class CXXHalfP(np.float16):
pass


cxx_complex = type('std::complex<float>', (c_complex,), {})
cxx_double_complex = type('std::complex<double>', (c_double_complex,), {})

ctypes_vector_mapper[CXXCFloat] = cxx_complex
ctypes_vector_mapper[CXXCDouble] = cxx_double_complex
ctypes_vector_mapper[CXXHalf] = c_float16
ctypes_vector_mapper[CXXHalfP] = c_float16_p


class CXXBB(LangBB):
Expand All @@ -75,9 +85,11 @@ class CXXBB(LangBB):
Call('free', (i,)),
'alloc-global-symbol': lambda i, j, k:
Call('memcpy', (i, j, k)),
# Complex
# Complex and float16
'header-complex': 'complex',
'complex-namespace': [UsingNamespace('std::complex_literals')],
'def-complex': std_arith,
'types': {np.complex128: CXXCDouble, np.complex64: CXXCFloat},
'types': {np.complex128: CXXCDouble, np.complex64: CXXCFloat,
np.float16: CXXHalf},
'half_types': (CXXHalf, CXXHalfP),
enwask marked this conversation as resolved.
Show resolved Hide resolved
}
51 changes: 50 additions & 1 deletion devito/symbolics/extended_dtypes.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import ctypes as ct
enwask marked this conversation as resolved.
Show resolved Hide resolved
import numpy as np

from devito.symbolics.extended_sympy import ReservedWord, Cast, CastStar, ValueLimit
from devito.tools import (Bunch, float2, float3, float4, double2, double3, double4, # noqa
int2, int3, int4)

__all__ = ['cast_mapper', 'CustomType', 'limits_mapper', 'INT', 'FLOAT', 'DOUBLE', 'VOID'] # noqa
__all__ = ['cast_mapper', 'CustomType', 'limits_mapper', 'INT', 'FLOAT',
'DOUBLE', 'VOID', 'NoDeclStruct', 'c_complex', 'c_double_complex',
'c_float16', 'c_float16_p']


limits_mapper = {
Expand All @@ -15,6 +18,52 @@
}


class NoDeclStruct(ct.Structure):
# ctypes.Structure that does not generate a struct definition
enwask marked this conversation as resolved.
Show resolved Hide resolved
pass


class c_complex(NoDeclStruct):
# Structure for passing complex float to C/C++
_fields_ = [('real', ct.c_float), ('imag', ct.c_float)]

@classmethod
def from_param(cls, val):
return cls(val.real, val.imag)


class c_double_complex(NoDeclStruct):
# Structure for passing complex double to C/C++
_fields_ = [('real', ct.c_double), ('imag', ct.c_double)]

@classmethod
def from_param(cls, val):
return cls(val.real, val.imag)


class _c_half(ct.c_uint16):
# Ctype for non-scalar half floats
@classmethod
def from_param(cls, val):
return cls(np.float16(val).view(np.uint16))


c_float16 = type('_Float16', (_c_half,), {})
enwask marked this conversation as resolved.
Show resolved Hide resolved


class _c_half_p(ct.POINTER(c_float16)):
enwask marked this conversation as resolved.
Show resolved Hide resolved
# Ctype for half scalars; we can't directly pass _Float16 values so
# we use a pointer and dereference (see `passes.iet.dtypes`)
@classmethod
enwask marked this conversation as resolved.
Show resolved Hide resolved
def from_param(cls, val):
arr = np.array(val, dtype=np.float16)
return arr.ctypes.data_as(cls)


# ctypes directly parses class dict; can't inherit the _type_ attribute
c_float16_p = type('_Float16 *', (_c_half_p,), {'_type_': c_float16})
enwask marked this conversation as resolved.
Show resolved Hide resolved


class CustomType(ReservedWord):
pass

Expand Down
Loading