Skip to content

Commit

Permalink
Merge pull request #2244 from devitocodes/patch-lazy-threading
Browse files Browse the repository at this point in the history
compiler: Introduce symbolic fencing
  • Loading branch information
FabioLuporini authored Oct 25, 2023
2 parents ae40de0 + a000c49 commit 522d475
Show file tree
Hide file tree
Showing 10 changed files with 334 additions and 47 deletions.
56 changes: 43 additions & 13 deletions devito/ir/clusters/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
from devito.ir.equations import ClusterizedEq
from devito.ir.support import (PARALLEL, PARALLEL_IF_PVT, BaseGuardBoundNext,
Forward, Interval, IntervalGroup, IterationSpace,
DataSpace, Guards, Properties, Scope, detect_accesses,
detect_io, normalize_properties, normalize_syncs,
minimum, maximum, null_ispace)
DataSpace, Guards, Properties, Scope, WithLock,
PrefetchUpdate, detect_accesses, detect_io,
normalize_properties, normalize_syncs, minimum,
maximum, null_ispace)
from devito.mpi.halo_scheme import HaloScheme, HaloTouch
from devito.symbolics import estimate_cost
from devito.tools import as_tuple, flatten, frozendict, infer_dtype
from devito.types import WeakFence, CriticalRegion

__all__ = ["Cluster", "ClusterGroup"]

Expand Down Expand Up @@ -176,10 +178,6 @@ def functions(self):
def has_increments(self):
return any(e.is_Increment for e in self.exprs)

@cached_property
def is_scalar(self):
return not any(f.is_Function for f in self.scope.writes)

@cached_property
def grid(self):
grids = set(f.grid for f in self.functions if f.is_DiscreteFunction) - {None}
Expand All @@ -188,15 +186,21 @@ def grid(self):
else:
raise ValueError("Cluster has no unique Grid")

@cached_property
def is_scalar(self):
return not any(f.is_Function for f in self.scope.writes)

@cached_property
def is_dense(self):
"""
A Cluster is dense if at least one of the following conditions is True:
True if at least one of the following conditions are True:
* It is defined over a unique Grid and all of the Grid Dimensions
are PARALLEL.
* Only DiscreteFunctions are written and only affine index functions
are used (e.g., `a[x+1, y-2]` is OK, while `a[b[x], y-2]` is not)
False in all other cases.
"""
# Hopefully it's got a unique Grid and all Dimensions are PARALLEL (or
# at most PARALLEL_IF_PVT). This is a quick and easy check so we try it first
Expand All @@ -212,21 +216,47 @@ def is_dense(self):
# Fallback to legacy is_dense checks
return (not any(e.conditionals for e in self.exprs) and
not any(f.is_SparseFunction for f in self.functions) and
not self.is_halo_touch and
not self.is_wild and
all(a.is_regular for a in self.scope.accesses))

@cached_property
def is_sparse(self):
"""
A Cluster is sparse if it represents a sparse operation, i.e iff
There's at least one irregular access.
True if it represents a sparse operation, i.e iff there's at least
one irregular access, False otherwise.
"""
return any(a.is_irregular for a in self.scope.accesses)

@property
def is_wild(self):
"""
True if encoding a non-mathematical operation, False otherwise.
"""
return self.is_halo_touch or self.is_fence

@property
def is_halo_touch(self):
return (len(self.exprs) > 0 and
all(isinstance(e.rhs, HaloTouch) for e in self.exprs))
return self.exprs and all(isinstance(e.rhs, HaloTouch) for e in self.exprs)

@property
def is_fence(self):
return self.is_weak_fence or self.is_critical_region

@property
def is_weak_fence(self):
return self.exprs and all(isinstance(e.rhs, WeakFence) for e in self.exprs)

@property
def is_critical_region(self):
return self.exprs and all(isinstance(e.rhs, CriticalRegion) for e in self.exprs)

@property
def is_async(self):
"""
True if an asynchronous Cluster, False otherwise.
"""
return any(isinstance(s, (WithLock, PrefetchUpdate))
for s in flatten(self.syncs.values()))

@cached_property
def dtype(self):
Expand Down
4 changes: 2 additions & 2 deletions devito/ir/clusters/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,9 @@ def __init__(self, func, mode='dense'):
self.func = func

if mode == 'dense':
self.cond = lambda c: c.is_dense or not c.is_sparse
self.cond = lambda c: (c.is_dense or not c.is_sparse) and not c.is_wild
elif mode == 'sparse':
self.cond = lambda c: c.is_sparse
self.cond = lambda c: c.is_sparse and not c.is_wild
else:
self.cond = lambda c: True

Expand Down
18 changes: 16 additions & 2 deletions devito/ir/stree/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,15 +132,29 @@ def stree_build(clusters, profiler=None, **kwargs):

def preprocess(clusters, options=None, **kwargs):
"""
Remove the HaloTouch's from `clusters` and create a mapping associating
each removed HaloTouch to the first Cluster necessitating it.
Lower the so-called "wild" Clusters, that is objects not representing a set
of mathematical operations. This boils down to:
* Moving the HaloTouch's from `clusters` into a mapper `M: {HT -> C}`.
`c = M(ht)` is the first Cluster of the sequence requiring the halo
exchange `ht` to have terminated before the execution can proceed.
* Lower the CriticalRegions:
* If they encode an asynchronous operation (e.g., a WaitLock), attach
it to a Nop Cluster for future lowering;
* Otherwise, simply remove them, as they have served their purpose
at this point.
* Remove the WeakFences, as they have served their purpose at this point.
"""
queue = []
processed = []
for c in clusters:
if c.is_halo_touch:
hs = HaloScheme.union(e.rhs.halo_scheme for e in c.exprs)
queue.append(c.rebuild(halo_scheme=hs))
elif c.is_critical_region and c.syncs:
processed.append(c.rebuild(exprs=None, guards=c.guards, syncs=c.syncs))
elif c.is_wild:
continue
else:
dims = set(c.ispace.promote(lambda d: d.is_Block).itdims)

Expand Down
60 changes: 47 additions & 13 deletions devito/ir/support/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
q_constant, q_affine, q_routine, search, uxreplace)
from devito.tools import (Tag, as_mapper, as_tuple, is_integer, filter_sorted,
flatten, memoized_meth, memoized_generator)
from devito.types import (Barrier, ComponentAccess, Dimension, DimensionTuple,
Function, Jump, Symbol, Temp, TempArray, TBArray)
from devito.types import (ComponentAccess, Dimension, DimensionTuple, Fence,
CriticalRegion, Function, Symbol, Temp, TempArray,
TBArray)

__all__ = ['IterationInstance', 'TimedAccess', 'Scope', 'ExprGeometry']

Expand All @@ -23,10 +24,9 @@ class IndexMode(Tag):
REGULAR = IndexMode('regular')
IRREGULAR = IndexMode('irregular')

mocksym = Symbol(name='⋈')
"""
A Symbol to create mock data depdendencies.
"""
# Symbols to create mock data depdendencies
mocksym0 = Symbol(name='__⋈_0__')
mocksym1 = Symbol(name='__⋈_1__')


class IterationInstance(LabeledVector):
Expand Down Expand Up @@ -848,9 +848,21 @@ def writes_gen(self):

# Objects altering the control flow (e.g., synchronization barriers,
# break statements, ...) are converted into mock dependences

# Fences (any sort) cannot float around upon topological sorting
for i, e in enumerate(self.exprs):
if isinstance(e.rhs, (Barrier, Jump)):
yield TimedAccess(mocksym, 'W', i, e.ispace)
if isinstance(e.rhs, Fence):
yield TimedAccess(mocksym0, 'W', i, e.ispace)

# CriticalRegions are stronger than plain Fences.
# We must also ensure that none of the Eqs within an opening-closing
# CriticalRegion pair floats outside upon topological sorting
for i, e in enumerate(self.exprs):
if isinstance(e.rhs, CriticalRegion) and e.rhs.opening:
for j, e1 in enumerate(self.exprs[i+1:], 1):
if isinstance(e1.rhs, CriticalRegion) and e1.rhs.closing:
break
yield TimedAccess(mocksym1, 'W', i+j, e1.ispace)

@cached_property
def writes(self):
Expand Down Expand Up @@ -904,12 +916,32 @@ def reads_implicit_gen(self):
for i in symbols:
yield TimedAccess(i, 'R', -1)

@memoized_generator
def reads_synchro_gen(self):
"""
Generate all reads due to syncronization operations. These may be explicit
or implicit.
"""
# Objects altering the control flow (e.g., synchronization barriers,
# break statements, ...) are converted into mock dependences

# Fences (any sort) cannot float around upon topological sorting
for i, e in enumerate(self.exprs):
if isinstance(e.rhs, Fence):
if i > 0:
yield TimedAccess(mocksym0, 'R', i-1, e.ispace)
if i < len(self.exprs)-1:
yield TimedAccess(mocksym0, 'R', i+1, e.ispace)

# CriticalRegions are stronger than plain Fences.
# We must also ensure that none of the Eqs within an opening-closing
# CriticalRegion pair floats outside upon topological sorting
for i, e in enumerate(self.exprs):
if isinstance(e.rhs, (Barrier, Jump)):
yield TimedAccess(mocksym, 'R', max(i, 0), e.ispace)
yield TimedAccess(mocksym, 'R', i+1, e.ispace)
if isinstance(e.rhs, CriticalRegion):
if e.rhs.opening and i > 0:
yield TimedAccess(mocksym1, 'R', i-1, self.exprs[i-1].ispace)
elif e.rhs.closing and i < len(self.exprs)-1:
yield TimedAccess(mocksym1, 'R', i+1, self.exprs[i+1].ispace)

@memoized_generator
def reads_gen(self):
Expand All @@ -920,7 +952,9 @@ def reads_gen(self):
# is efficiency. Sometimes we wish to extract all reads to a given
# AbstractFunction, and we know that by construction these can't
# appear among the implicit reads
return chain(self.reads_explicit_gen(), self.reads_implicit_gen())
return chain(self.reads_explicit_gen(),
self.reads_synchro_gen(),
self.reads_implicit_gen())

@memoized_generator
def reads_smart_gen(self, f):
Expand All @@ -939,7 +973,7 @@ def reads_smart_gen(self, f):
the iteration symbols.
"""
if isinstance(f, (Function, Temp, TempArray, TBArray)):
for i in self.reads_explicit_gen():
for i in chain(self.reads_explicit_gen(), self.reads_synchro_gen()):
if f is i.function:
for j in extrema(i.access):
yield TimedAccess(j, i.mode, i.timestamp, i.ispace)
Expand Down
8 changes: 7 additions & 1 deletion devito/passes/clusters/asynchrony.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from devito.ir import (Forward, GuardBoundNext, Queue, Vector, WaitLock, WithLock,
FetchUpdate, PrefetchUpdate, ReleaseLock, normalize_syncs)
from devito.passes.clusters.utils import is_memcpy
from devito.passes.clusters.utils import bind_critical_regions, is_memcpy
from devito.symbolics import IntDiv, uxreplace
from devito.tools import OrderedSet, is_integer, timed_pass
from devito.types import CustomDimension, Lock
Expand Down Expand Up @@ -139,6 +139,12 @@ def callback(self, clusters, prefix):
tasks[c0].append(ReleaseLock(lock[i], target))
tasks[c0].append(WithLock(lock[i], target, i, function, findex, d))

# CriticalRegions preempt WaitLocks, by definition
mapper = bind_critical_regions(clusters)
for c in clusters:
for c1 in mapper.get(c, []):
waits[c].update(waits.pop(c1, []))

processed = []
for c in clusters:
if waits[c] or tasks[c]:
Expand Down
8 changes: 5 additions & 3 deletions devito/passes/clusters/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from devito.ir.clusters import Cluster, ClusterGroup, Queue, cluster_pass
from devito.ir.support import (SEQUENTIAL, SEPARABLE, Scope, ReleaseLock,
WaitLock, WithLock, FetchUpdate, PrefetchUpdate)
from devito.passes.clusters.utils import in_critical_region
from devito.symbolics import pow_to_mul
from devito.tools import DAG, Stamp, as_tuple, flatten, frozendict, timed_pass
from devito.types import Hyperplane
Expand Down Expand Up @@ -44,8 +45,9 @@ def callback(self, clusters, prefix):
processed.append(c)
continue

# Synchronization operations prevent lifting
if c.syncs.get(dim):
# Synchronization prevents lifting
if c.syncs.get(dim) or \
in_critical_region(c, clusters):
processed.append(c)
continue

Expand Down Expand Up @@ -262,7 +264,7 @@ def dump():

groups, processed = processed, []
for group in groups:
for flag, minigroup in groupby(group, key=lambda c: c.is_halo_touch):
for flag, minigroup in groupby(group, key=lambda c: c.is_wild):
if flag:
processed.extend([(c,) for c in minigroup])
else:
Expand Down
50 changes: 48 additions & 2 deletions devito/passes/clusters/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from collections import defaultdict

from devito.ir import Cluster
from devito.symbolics import uxreplace
from devito.types import Symbol, Wildcard
from devito.tools import as_tuple, flatten
from devito.types import CriticalRegion, Eq, Symbol, Wildcard

__all__ = ['makeit_ssa', 'is_memcpy']
__all__ = ['makeit_ssa', 'is_memcpy', 'make_critical_sequence',
'bind_critical_regions', 'in_critical_region']


def makeit_ssa(exprs):
Expand Down Expand Up @@ -48,3 +53,44 @@ def is_memcpy(expr):
return False

return a.function.is_Array or b.function.is_Array


def make_critical_sequence(ispace, sequence, **kwargs):
sequence = as_tuple(sequence)
assert len(sequence) >= 1

processed = []

# Opening
expr = Eq(Symbol(name='⋈'), CriticalRegion(True))
processed.append(Cluster(exprs=expr, ispace=ispace, **kwargs))

processed.extend(sequence)

# Closing
expr = Eq(Symbol(name='⋈'), CriticalRegion(False))
processed.append(Cluster(exprs=expr, ispace=ispace, **kwargs))

return processed


def bind_critical_regions(clusters):
"""
A mapper from CriticalRegions to the critical sequences they open.
"""
critical_region = False
mapper = defaultdict(list)
for c in clusters:
if c.is_critical_region:
critical_region = not critical_region and c
elif critical_region:
mapper[critical_region].append(c)
return mapper


def in_critical_region(cluster, clusters):
"""
True if `cluster` is part of a critical sequence, False otherwise.
"""
mapper = bind_critical_regions(clusters)
return cluster in flatten(mapper.values())
Loading

0 comments on commit 522d475

Please sign in to comment.