Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Fix complex arguments and implement float16 lowering #2403

Open
wants to merge 18 commits into
base: complex
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions devito/data/allocators.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

from devito.logger import logger
from devito.parameters import configuration
from devito.tools import dtype_to_ctype, is_integer
from devito.tools import is_integer
from devito.tools.dtypes_lowering import dtype_alloc_ctype
enwask marked this conversation as resolved.
Show resolved Hide resolved

__all__ = ['ALLOC_ALIGNED', 'ALLOC_NUMA_LOCAL', 'ALLOC_NUMA_ANY',
'ALLOC_KNL_MCDRAM', 'ALLOC_KNL_DRAM', 'ALLOC_GUARD',
Expand Down Expand Up @@ -92,12 +93,8 @@ def initialize(cls):
return

def alloc(self, shape, dtype, padding=0):
# For complex number, allocate double the size of its real/imaginary part
alloc_dtype = dtype(0).real.__class__
c_scale = 2 if np.issubdtype(dtype, np.complexfloating) else 1

ctype, c_scale = dtype_alloc_ctype(dtype)
datasize = int(reduce(mul, shape) * c_scale)
ctype = dtype_to_ctype(alloc_dtype)

# Add padding, if any
try:
Expand Down
23 changes: 17 additions & 6 deletions devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""The Iteration/Expression Tree (IET) hierarchy."""

import abc
import ctypes
import inspect
from functools import cached_property
from collections import OrderedDict, namedtuple
Expand Down Expand Up @@ -1030,6 +1031,8 @@ class Dereference(ExprStmt, Node):
* `pointer` is a PointerArray or TempFunction, and `pointee` is an Array.
* `pointer` is an ArrayObject representing a pointer to a C struct, and
`pointee` is a field in `pointer`.
* `pointer` is a Symbol whose `_C_ctype` derives from `ctypes._Pointer`, and
`pointee` is a Symbol representing the dereferenced value.
georgebisbas marked this conversation as resolved.
Show resolved Hide resolved
"""

is_Dereference = True
Expand All @@ -1048,13 +1051,21 @@ def functions(self):

@property
def expr_symbols(self):
ret = [self.pointer.indexed]
if self.pointer.is_PointerArray or self.pointer.is_TempFunction:
ret.append(self.pointee.indexed)
ret.extend(flatten(i.free_symbols for i in self.pointee.symbolic_shape[1:]))
ret.extend(self.pointer.free_symbols)
else:
ret = []
if self.pointer.is_Symbol:
assert issubclass(self.pointer._C_ctype, ctypes._Pointer), \
"Scalar dereference must have a pointer ctype"
ret.append(self.pointer._C_symbol)
enwask marked this conversation as resolved.
Show resolved Hide resolved
ret.append(self.pointee._C_symbol)
else:
ret.append(self.pointer.indexed)
if self.pointer.is_PointerArray or self.pointer.is_TempFunction:
enwask marked this conversation as resolved.
Show resolved Hide resolved
ret.append(self.pointee.indexed)
ret.extend(flatten(i.free_symbols
for i in self.pointee.symbolic_shape[1:]))
ret.extend(self.pointer.free_symbols)
else:
ret.append(self.pointee._C_symbol)
return tuple(filter_ordered(ret))

@property
Expand Down
10 changes: 8 additions & 2 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
from devito.ir.support.space import Backward
from devito.symbolics import (FieldFromComposite, FieldFromPointer,
ListInitializer, ccode, uxreplace)
from devito.symbolics.extended_dtypes import NoDeclStruct
from devito.tools import (GenericVisitor, as_tuple, ctypes_to_cstr, filter_ordered,
filter_sorted, flatten, is_external_ctype,
c_restrict_void_p, sorted_priority)
from devito.types.basic import AbstractFunction, Basic
from devito.types.basic import AbstractFunction, AbstractSymbol, Basic
from devito.types import (ArrayObject, CompositeObject, Dimension, Pointer,
IndexedData, DeviceMap)

Expand Down Expand Up @@ -208,7 +209,7 @@ def _gen_struct_decl(self, obj, masked=()):
while issubclass(ctype, ctypes._Pointer):
ctype = ctype._type_

if not issubclass(ctype, ctypes.Structure):
if not issubclass(ctype, ctypes.Structure) or issubclass(ctype, NoDeclStruct):
enwask marked this conversation as resolved.
Show resolved Hide resolved
return None
except TypeError:
# E.g., `ctype` is of type `dtypes_lowering.CustomDtype`
Expand Down Expand Up @@ -453,6 +454,9 @@ def visit_Dereference(self, o):
lvalue = c.Value(cstr, '*restrict %s' % a0.name)
if a0._data_alignment:
lvalue = c.AlignedAttribute(a0._data_alignment, lvalue)
elif a1.is_Symbol:
enwask marked this conversation as resolved.
Show resolved Hide resolved
rvalue = '*%s' % a1.name
lvalue = self._gen_value(a0, 0)
else:
rvalue = '%s->%s' % (a1.name, a0._C_name)
lvalue = self._gen_value(a0, 0)
Expand Down Expand Up @@ -957,6 +961,7 @@ def default_retval(cls):
Drive the search. Accepted:
- `symbolics`: Collect all AbstractFunction objects, default
- `basics`: Collect all Basic objects
- `scalars`: Collect all AbstractSymbol objects
enwask marked this conversation as resolved.
Show resolved Hide resolved
- `dimensions`: Collect all Dimensions
- `indexeds`: Collect all Indexed objects
- `indexedbases`: Collect all IndexedBase objects
Expand All @@ -977,6 +982,7 @@ def _defines_aliases(n):
rules = {
'symbolics': lambda n: n.functions,
'basics': lambda n: [i for i in n.expr_symbols if isinstance(i, Basic)],
'scalars': lambda n: [i for i in n.expr_symbols if isinstance(i, AbstractSymbol)],
enwask marked this conversation as resolved.
Show resolved Hide resolved
'dimensions': lambda n: [i for i in n.expr_symbols if isinstance(i, Dimension)],
'indexeds': lambda n: [i for i in n.expr_symbols if i.is_Indexed],
'indexedbases': lambda n: [i for i in n.expr_symbols
Expand Down
4 changes: 2 additions & 2 deletions devito/passes/iet/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from devito.ir import (Block, Call, Definition, DummyExpr, Return, EntryFunction,
FindSymbols, MapExprStmts, Transformer, make_callable)
from devito.passes import is_gpu_create
from devito.passes.iet.dtypes import lower_complex
from devito.passes.iet.dtypes import lower_dtypes
from devito.passes.iet.engine import iet_pass
from devito.passes.iet.langbase import LangBB
from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer,
Expand Down Expand Up @@ -414,7 +414,7 @@ def place_casts(self, iet, **kwargs):

@iet_pass
def make_langtypes(self, iet):
iet, metadata = lower_complex(iet, self.lang, self.compiler)
iet, metadata = lower_dtypes(iet, self.lang, self.compiler, self.sregistry)
return iet, metadata

def process(self, graph):
Expand Down
70 changes: 47 additions & 23 deletions devito/passes/iet/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,53 @@
import ctypes

from devito.ir import FindSymbols, Uxreplace
from devito.ir.iet.nodes import Dereference
from devito.tools.utils import as_list
from devito.types.basic import Symbol

__all__ = ['lower_complex']
__all__ = ['lower_dtypes']


def lower_complex(iet, lang, compiler):
def lower_dtypes(iet, lang, compiler, sregistry):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how can this not be an @iet_pass and still work?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's called from an @iet_pass in definitions here

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah OK, I may have told Mathias already, but IMHO that thing should go straight into operator/operator.py::Operator::_lower_iet

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah okay, I can make this change

"""
Lower language-specific dtypes and add headers for complex arithmetic
"""
# Include complex headers if needed (before we replace complex dtypes)
metadata = _complex_includes(iet, lang, compiler)

body_prefix = [] # Derefs to prepend to the body
body_mapper = {}
enwask marked this conversation as resolved.
Show resolved Hide resolved
params_mapper = {}

# Lower scalar float16s to pointers and dereference them
if lang.get('half_types') is not None:
enwask marked this conversation as resolved.
Show resolved Hide resolved
half, half_p = lang['half_types'] # dtype mappings for half float

for s in FindSymbols('scalars').visit(iet):
if s.dtype != np.float16 or s not in iet.parameters:
continue

ptr = s._rebuild(dtype=half_p, is_const=True)
val = Symbol(name=sregistry.make_name(prefix='hf'), dtype=half,
is_const=s.is_const)

params_mapper[s], body_mapper[s] = ptr, val
body_prefix.append(Dereference(val, ptr)) # val = *ptr

# Lower remaining language-specific dtypes
for s in FindSymbols('indexeds|basics|symbolics').visit(iet):
if s.dtype in lang['types'] and s not in params_mapper:
body_mapper[s] = params_mapper[s] = s._rebuild(dtype=lang['types'][s.dtype])

# Apply the dtype replacements
body = body_prefix + as_list(Uxreplace(body_mapper).visit(iet.body))
enwask marked this conversation as resolved.
Show resolved Hide resolved
params = Uxreplace(params_mapper).visit(iet.parameters)

iet = iet._rebuild(body=body, parameters=params)
return iet, metadata


def _complex_includes(iet, lang, compiler):
"""
Add headers for complex arithmetic
"""
Expand All @@ -15,11 +57,11 @@ def lower_complex(iet, lang, compiler):
if not issubclass(f.dtype, ctypes._Pointer)}

if not any(np.issubdtype(d, np.complexfloating) for d in types):
return iet, {}
return {}

metadata = {}
lib = (lang['header-complex'],)

metadata = {}
if lang.get('complex-namespace') is not None:
metadata['namespaces'] = lang['complex-namespace']

Expand All @@ -31,24 +73,6 @@ def lower_complex(iet, lang, compiler):
ff.write(str(lang['def-complex']))
lib += (str(hfile),)

iet = _complex_dtypes(iet, lang)
metadata['includes'] = lib

return iet, metadata


def _complex_dtypes(iet, lang):
"""
Lower dtypes to language specific types
"""
mapper = {}

for s in FindSymbols('indexeds|basics|symbolics').visit(iet):
if s.dtype in lang['types']:
mapper[s] = s._rebuild(dtype=lang['types'][s.dtype])

body = Uxreplace(mapper).visit(iet.body)
params = Uxreplace(mapper).visit(iet.parameters)
iet = iet._rebuild(body=body, parameters=params)

return iet
return metadata
31 changes: 23 additions & 8 deletions devito/passes/iet/languages/C.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import ctypes as ct
import numpy as np

from devito.ir import Call
from devito.passes.iet.definitions import DataManager
from devito.passes.iet.orchestration import Orchestrator
from devito.passes.iet.langbase import LangBB
from devito.symbolics.extended_dtypes import (c_complex, c_double_complex,
c_half, c_half_p)
from devito.tools.dtypes_lowering import ctypes_vector_mapper


__all__ = ['CBB', 'CDataManager', 'COrchestrator']
__all__ = ['CBB', 'CDataManager', 'COrchestrator', 'c_float16', 'c_float16_p']


class CCFloat(np.complex64):
Expand All @@ -19,11 +20,24 @@ class CCDouble(np.complex128):
pass


c_complex = type('_Complex float', (ct.c_double,), {})
c_double_complex = type('_Complex double', (ct.c_longdouble,), {})
class CHalf(np.float16):
pass


class CHalfP(np.float16):
pass


c99_complex = type('_Complex float', (c_complex,), {})
c99_double_complex = type('_Complex double', (c_double_complex,), {})

c_float16 = type('_Float16', (c_half,), {})
c_float16_p = type('_Float16 *', (c_half_p,), {'_type_': c_float16})

ctypes_vector_mapper[CCFloat] = c_complex
ctypes_vector_mapper[CCDouble] = c_double_complex
ctypes_vector_mapper[CCFloat] = c99_complex
ctypes_vector_mapper[CCDouble] = c99_double_complex
ctypes_vector_mapper[CHalf] = c_float16
ctypes_vector_mapper[CHalfP] = c_float16_p


class CBB(LangBB):
Expand All @@ -40,9 +54,10 @@ class CBB(LangBB):
Call('free', (i,)),
'alloc-global-symbol': lambda i, j, k:
Call('memcpy', (i, j, k)),
# Complex
# Complex and float16
'header-complex': 'complex.h',
'types': {np.complex128: CCDouble, np.complex64: CCFloat},
'types': {np.complex128: CCDouble, np.complex64: CCFloat, np.float16: CHalf},
'half_types': (CHalf, CHalfP),
}


Expand Down
22 changes: 17 additions & 5 deletions devito/passes/iet/languages/CXX.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import ctypes as ct
import numpy as np

from devito.ir import Call, UsingNamespace
from devito.passes.iet.langbase import LangBB
from devito.passes.iet.languages.C import c_float16, c_float16_p
from devito.symbolics.extended_dtypes import c_complex, c_double_complex
from devito.tools.dtypes_lowering import ctypes_vector_mapper

__all__ = ['CXXBB']
Expand Down Expand Up @@ -53,12 +54,21 @@ class CXXCDouble(np.complex128):
pass


cxx_complex = type('std::complex<float>', (ct.c_double,), {})
cxx_double_complex = type('std::complex<double>', (ct.c_longdouble,), {})
class CXXHalf(np.float16):
pass


class CXXHalfP(np.float16):
pass


cxx_complex = type('std::complex<float>', (c_complex,), {})
cxx_double_complex = type('std::complex<double>', (c_double_complex,), {})

ctypes_vector_mapper[CXXCFloat] = cxx_complex
ctypes_vector_mapper[CXXCDouble] = cxx_double_complex
ctypes_vector_mapper[CXXHalf] = c_float16
ctypes_vector_mapper[CXXHalfP] = c_float16_p


class CXXBB(LangBB):
Expand All @@ -75,9 +85,11 @@ class CXXBB(LangBB):
Call('free', (i,)),
'alloc-global-symbol': lambda i, j, k:
Call('memcpy', (i, j, k)),
# Complex
# Complex and float16
'header-complex': 'complex',
'complex-namespace': [UsingNamespace('std::complex_literals')],
'def-complex': std_arith,
'types': {np.complex128: CXXCDouble, np.complex64: CXXCFloat},
'types': {np.complex128: CXXCDouble, np.complex64: CXXCFloat,
np.float16: CXXHalf},
'half_types': (CXXHalf, CXXHalfP),
enwask marked this conversation as resolved.
Show resolved Hide resolved
}
Loading