Skip to content

Commit

Permalink
Merge pull request #2253 from devitocodes/vector-types-push
Browse files Browse the repository at this point in the history
compiler: Machinery to generate vector types
  • Loading branch information
mloubout authored Oct 31, 2023
2 parents fafa58a + cdb83e4 commit d4ebfa9
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 73 deletions.
74 changes: 39 additions & 35 deletions devito/mpi/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
IndexedPointer, Macro, cast_mapper, subs_op_args)
from devito.tools import (as_mapper, dtype_to_mpitype, dtype_len, dtype_to_ctype,
flatten, generator, is_integer, split)
from devito.types import (Array, Bundle, Dimension, Eq, Symbol, LocalObject,
from devito.types import (Array, Bag, Dimension, Eq, Symbol, LocalObject,
CompositeObject, CustomDimension)

__all__ = ['HaloExchangeBuilder', 'mpi_registry']
Expand Down Expand Up @@ -291,22 +291,22 @@ def _make_bundles(self, hs):

mapper = as_mapper(halo_scheme.fmapper, lambda i: halo_scheme.fmapper[i])
for hse, components in mapper.items():
# We recast everything as Bundles for simplicity -- worst case scenario
# all Bundles only have one component. Existing Bundles are preserved
# We recast everything as Bags for simplicity -- worst case scenario
# all Bags only have one component. Existing Bundles are preserved
halo_scheme = halo_scheme.drop(components)
bundles, candidates = split(tuple(components), lambda i: i.is_Bundle)
for b in bundles:
halo_scheme = halo_scheme.add(b, hse)

try:
name = "bundle_%s" % "".join(f.name for f in candidates)
bundle = Bundle(name=name, components=candidates)
halo_scheme = halo_scheme.add(bundle, hse)
name = "bag_%s" % "".join(f.name for f in candidates)
bag = Bag(name=name, components=candidates)
halo_scheme = halo_scheme.add(bag, hse)
except ValueError:
for i in candidates:
name = "bundle_%s" % i.name
bundle = Bundle(name=name, components=i)
halo_scheme = halo_scheme.add(bundle, hse)
name = "bag_%s" % i.name
bag = Bag(name=name, components=i)
halo_scheme = halo_scheme.add(bag, hse)

hs = hs._rebuild(halo_scheme=halo_scheme)

Expand Down Expand Up @@ -362,13 +362,17 @@ def _make_copy(self, f, hse, key, swap=False):
else:
swap = lambda i, j: (j, i)
name = 'scatter%s' % key
for i, c in enumerate(f.components):
eqns.append(Eq(*swap(buf[[i] + bdims], c[findices])))
if isinstance(f, Bag):
for i, c in enumerate(f.components):
eqns.append(Eq(*swap(buf[[i] + bdims], c[findices])))
else:
for i in range(f.ncomp):
eqns.append(Eq(*swap(buf[[i] + bdims], f[[i] + findices])))

# Compile `eqns` into an IET via recursive compilation
irs, _ = self.rcompile(eqns)

parameters = [buf] + bshape + list(f.components) + ofs
parameters = [buf] + bshape + list(f.handles) + ofs

return CopyBuffer(name, irs.uiet, parameters)

Expand All @@ -391,9 +395,9 @@ def _make_sendrecv(self, f, hse, key, **kwargs):

shape = [d.symbolic_size for d in dims]

arguments = [bufg] + shape + list(f.components) + ofsg
arguments = [bufg] + shape + list(f.handles) + ofsg
gather = Gather('gather%s' % key, arguments)
arguments = [bufs] + shape + list(f.components) + ofss
arguments = [bufs] + shape + list(f.handles) + ofss
scatter = Scatter('scatter%s' % key, arguments)

# The `gather` is unnecessary if sending to MPI.PROC_NULL
Expand All @@ -415,13 +419,13 @@ def _make_sendrecv(self, f, hse, key, **kwargs):

iet = List(body=[recv, gather, send, waitsend, waitrecv, scatter])

parameters = (list(f.components) + shape + ofsg + ofss +
parameters = (list(f.handles) + shape + ofsg + ofss +
[fromrank, torank, comm])

return SendRecv('sendrecv%s' % key, iet, parameters, bufg, bufs)

def _call_sendrecv(self, name, *args, **kwargs):
args = list(args[0].components) + flatten(args[1:])
args = list(args[0].handles) + flatten(args[1:])
return Call(name, args)

def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
Expand Down Expand Up @@ -475,14 +479,14 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):

iet = List(body=body)

parameters = list(f.components) + [comm, nb] + list(fixed.values())
parameters = list(f.handles) + [comm, nb] + list(fixed.values())

return HaloUpdate('haloupdate%s' % key, iet, parameters)

def _call_haloupdate(self, name, f, hse, *args):
comm = f.grid.distributor._obj_comm
nb = f.grid.distributor._obj_neighborhood
args = list(f.components) + [comm, nb] + list(hse.loc_indices.values())
args = list(f.handles) + [comm, nb] + list(hse.loc_indices.values())
return HaloUpdateCall(name, flatten(args))

def _make_compute(self, *args):
Expand Down Expand Up @@ -567,7 +571,7 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):

iet = List(body=body)

parameters = list(f.components) + [comm, nb] + list(fixed.values())
parameters = list(f.handles) + [comm, nb] + list(fixed.values())

return HaloUpdate('haloupdate%s' % key, iet, parameters)

Expand Down Expand Up @@ -614,7 +618,7 @@ def _make_sendrecv(self, f, hse, key, msg=None):
sizes = [FieldFromPointer('%s[%d]' % (msg._C_field_sizes, i), msg)
for i in range(len(f._dist_dimensions))]

arguments = [cast(bufg)] + sizes + list(f.components) + ofsg
arguments = [cast(bufg)] + sizes + list(f.handles) + ofsg
gather = Gather('gather%s' % key, arguments)
# The `gather` is unnecessary if sending to MPI.PROC_NULL
gather = Conditional(CondNe(torank, Macro('MPI_PROC_NULL')), gather)
Expand All @@ -629,7 +633,7 @@ def _make_sendrecv(self, f, hse, key, msg=None):

iet = List(body=[recv, gather, send])

parameters = list(f.components) + ofsg + [fromrank, torank, comm, msg]
parameters = list(f.handles) + ofsg + [fromrank, torank, comm, msg]

return SendRecv('sendrecv%s' % key, iet, parameters, bufg, bufs)

Expand All @@ -639,7 +643,7 @@ def _call_sendrecv(self, name, *args, msg=None, haloid=None):
# to collect and scatter the result of an MPI_Irecv
f, _, ofsg, _, fromrank, torank, comm = args
msg = Byref(IndexedPointer(msg, haloid))
return Call(name, list(f.components) + ofsg + [fromrank, torank, comm, msg])
return Call(name, list(f.handles) + ofsg + [fromrank, torank, comm, msg])

def _make_haloupdate(self, f, hse, key, sendrecv, msg=None):
iet = super()._make_haloupdate(f, hse, key, sendrecv, msg=msg)
Expand Down Expand Up @@ -676,7 +680,7 @@ def _make_wait(self, f, hse, key, msg=None):

sizes = [FieldFromPointer('%s[%d]' % (msg._C_field_sizes, i), msg)
for i in range(len(f._dist_dimensions))]
arguments = [cast(bufs)] + sizes + list(f.components) + ofss
arguments = [cast(bufs)] + sizes + list(f.handles) + ofss
scatter = Scatter('scatter%s' % key, arguments)

# The `scatter` must be guarded as we must not alter the halo values along
Expand All @@ -690,7 +694,7 @@ def _make_wait(self, f, hse, key, msg=None):

iet = List(body=[waitsend, waitrecv, scatter])

parameters = (list(f.components) + ofss + [fromrank, msg])
parameters = (list(f.handles) + ofss + [fromrank, msg])

return Callable('wait_%s' % key, iet, 'void', parameters, ('static',))

Expand All @@ -712,18 +716,18 @@ def _make_halowait(self, f, hse, key, wait, msg=None):

msgi = Byref(IndexedPointer(msg, len(body)))

arguments = list(f.components) + ofss + [fromrank, msgi]
arguments = list(f.handles) + ofss + [fromrank, msgi]
body.append(Call(wait.name, arguments))

iet = List(body=body)

parameters = list(f.components) + list(fixed.values()) + [nb, msg]
parameters = list(f.handles) + list(fixed.values()) + [nb, msg]

return Callable('halowait%d' % key, iet, 'void', parameters, ('static',))

def _call_halowait(self, name, f, hse, msg):
nb = f.grid.distributor._obj_neighborhood
arguments = list(f.components) + list(hse.loc_indices.values()) + [nb, msg]
arguments = list(f.handles) + list(hse.loc_indices.values()) + [nb, msg]
return HaloWaitCall(name, arguments)

def _make_remainder(self, hs, key, callcompute, *args):
Expand Down Expand Up @@ -789,7 +793,7 @@ def _make_haloupdate(self, f, hse, key, *args, msg=None):
ofsg = [fixed.get(d) or ofsg.pop(0) for d in f.dimensions]

# The `gather` is unnecessary if sending to MPI.PROC_NULL
arguments = [cast(bufg)] + sizes + list(f.components) + ofsg
arguments = [cast(bufg)] + sizes + list(f.handles) + ofsg
gather = Gather('gather%s' % key, arguments)
gather = Conditional(CondNe(torank, Macro('MPI_PROC_NULL')), gather)

Expand All @@ -805,12 +809,12 @@ def _make_haloupdate(self, f, hse, key, *args, msg=None):
# The -1 below is because an Iteration, by default, generates <=
ncomms = Symbol(name='ncomms')
iet = Iteration([recv, gather, send], dim, ncomms - 1)
parameters = f.components + (comm, msg, ncomms) + tuple(fixed.values())
parameters = f.handles + (comm, msg, ncomms) + tuple(fixed.values())
return HaloUpdate('haloupdate%s' % key, iet, parameters)

def _call_haloupdate(self, name, f, hse, msg):
comm = f.grid.distributor._obj_comm
args = f.components + (comm, msg, msg.npeers) + tuple(hse.loc_indices.values())
args = f.handles + (comm, msg, msg.npeers) + tuple(hse.loc_indices.values())
return HaloUpdateCall(name, args)

def _make_halowait(self, f, hse, key, *args, msg=None):
Expand All @@ -834,7 +838,7 @@ def _make_halowait(self, f, hse, key, *args, msg=None):

# The `scatter` must be guarded as we must not alter the halo values along
# the domain boundary, where the sender is actually MPI.PROC_NULL
arguments = [cast(bufs)] + sizes + list(f.components) + ofss
arguments = [cast(bufs)] + sizes + list(f.handles) + ofss
scatter = Scatter('scatter%s' % key, arguments)
scatter = Conditional(CondNe(fromrank, Macro('MPI_PROC_NULL')), scatter)

Expand All @@ -846,11 +850,11 @@ def _make_halowait(self, f, hse, key, *args, msg=None):
# The -1 below is because an Iteration, by default, generates <=
ncomms = Symbol(name='ncomms')
iet = Iteration([waitsend, waitrecv, scatter], dim, ncomms - 1)
parameters = f.components + tuple(fixed.values()) + (msg, ncomms)
parameters = f.handles + tuple(fixed.values()) + (msg, ncomms)
return Callable('halowait%d' % key, iet, 'void', parameters, ('static',))

def _call_halowait(self, name, f, hse, msg):
args = f.components + tuple(hse.loc_indices.values()) + (msg, msg.npeers)
args = f.handles + tuple(hse.loc_indices.values()) + (msg, msg.npeers)
return HaloWaitCall(name, args)

def _make_wait(self, *args, **kwargs):
Expand Down Expand Up @@ -1066,7 +1070,7 @@ class MPICall(Call):
@property
def ncomps(self):
"""
The number of Bundle components this MPICall was constructed for.
The number of components this MPICall was constructed for.
"""
return len([f for f in self.functions if f.is_DiscreteFunction])

Expand Down Expand Up @@ -1213,7 +1217,7 @@ def _arg_defaults(self, allocator, alias, args=None):

def _arg_values(self, args=None, **kwargs):
# Any will do
for f in self.target.components:
for f in self.target.handles:
try:
alias = kwargs[f.name]
break
Expand Down
2 changes: 1 addition & 1 deletion devito/passes/iet/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ def process(self, graph):
self.place_transfers(graph, mapper=mapper)
self.place_definitions(graph, globs=set())
self.place_devptr(graph)
self.place_bundling(graph)
self.place_bundling(graph, writes_input=graph.writes_input)
self.place_casts(graph)


Expand Down
55 changes: 35 additions & 20 deletions devito/passes/iet/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from devito.ir.support import SymbolRegistry
from devito.mpi.distributed import MPINeighborhood
from devito.tools import DAG, as_tuple, filter_ordered, timed_pass
from devito.types import (Array, CompositeObject, Lock, IncrDimension, Indirection,
Temp)
from devito.types import (Array, Bundle, CompositeObject, Lock, IncrDimension,
Indirection, Temp)
from devito.types.args import ArgProvider
from devito.types.dense import DiscreteFunction
from devito.types.dimension import AbstractIncrDimension, BlockDimension
Expand All @@ -19,14 +19,19 @@
class Graph(object):

"""
A special DAG representing call graphs.
DAG representation of a call graph.
The nodes of the graph are IET Callables; an edge from node `a` to node `b`
indicates that `b` calls `a`.
The nodes of the DAG are IET Callables.
An edge from node `a` to node `b` indicates that `b` calls `a`.
The `apply` method may be used to visit the Graph and apply a transformer `T`
to all nodes. This may change the state of the Graph: node `a` gets replaced
by `a' = T(a)`; new nodes (Callables), and therefore new edges, may be added.
The `apply` method visits the Graph and applies a transformation `T`
to all nodes. This may change the state of the Graph:
* Node `a` gets replaced by `a' = T(a)`.
* New nodes (Callables) may be added.
* Consequently, new edges may be added.
* New global objects may be introduced:
* Global symbols, header files, ...
The `visit` method collects info about the nodes in the Graph.
"""
Expand All @@ -40,6 +45,10 @@ def __init__(self, iet, sregistry=None):
self.headers = []
self.globals = []

# Stash immutable information useful for some compiler passes
writes = FindSymbols('writes').visit(iet)
self.writes_input = frozenset(f for f in writes if f.is_Input)

@property
def root(self):
return self.efuncs[list(self.efuncs).pop(0)]
Expand Down Expand Up @@ -295,6 +304,7 @@ def _(i, mapper, sregistry):


@abstract_object.register(Array)
@abstract_object.register(Bundle)
def _(i, mapper, sregistry):
if isinstance(i, Lock):
name = sregistry.make_name(prefix='lock')
Expand All @@ -308,6 +318,8 @@ def _(i, mapper, sregistry):
i.indexed: v.indexed,
i._C_symbol: v._C_symbol,
})
if i.dmap is not None:
mapper[i.dmap] = v.dmap


@abstract_object.register(CompositeObject)
Expand Down Expand Up @@ -406,24 +418,32 @@ def update_args(root, efuncs, dag):

# The parameters/arguments lists may have changed since a pass may have:
# 1) introduced a new symbol
new_args = derive_parameters(root)
new_params = derive_parameters(root)

# 2) defined a symbol for which no definition was available yet (e.g.
# via a malloc, or a Dereference)
defines = FindSymbols('defines').visit(root.body)
drop_args = [a for a in root.parameters if a in defines]
drop_params = [a for a in root.parameters if a in defines]

# 3) removed a symbol that was previously necessary (e.g., `x_size` after
# linearization)
symbols = FindSymbols('basics').visit(root.body)
drop_args.extend(a for a in root.parameters if a.is_Symbol and a not in symbols)
drop_params.extend(a for a in root.parameters
if a.is_Symbol and a not in symbols)

if not (new_args or drop_args):
# Must record the index, not the param itself, since a param may be
# bound to an arbitrary argument, possibly a generic SymPy expression
drop_params = [root.parameters.index(a) for a in drop_params]

if not (new_params or drop_params):
return efuncs

# Create the new parameters and arguments lists

def _filter(v, efunc=None):
processed = list(v)
for a in new_args:
processed = [a for i, a in enumerate(v) if i not in drop_params]

for a in new_params:
if a in processed:
# A child efunc trying to add a symbol already added by a
# sibling efunc
Expand All @@ -437,15 +457,10 @@ def _filter(v, efunc=None):

processed.append(a)

processed = [a for a in processed if a not in drop_args]

return processed

efuncs = OrderedDict(efuncs)

# Update to use the new signature
parameters = _filter(root.parameters, root)
efuncs[root.name] = root._rebuild(parameters=parameters)
efuncs[root.name] = root._rebuild(parameters=_filter(root.parameters, root))

# Update all call sites to use the new signature
for n in dag.downstream(root.name):
Expand Down
Loading

0 comments on commit d4ebfa9

Please sign in to comment.