Skip to content

Commit

Permalink
Merge pull request #119 from xdslproject/emilien/return-type
Browse files Browse the repository at this point in the history
Implement stencil.return operand type promotion
  • Loading branch information
georgebisbas authored Aug 30, 2024
2 parents ecd3900 + 06293fe commit 5c31f3c
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 39 deletions.
37 changes: 24 additions & 13 deletions devito/ir/xdsl_iet/cluster_to_ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr,
SSAargs = (self._visit_math_nodes(dim, arg, output_indexed)
for arg in node.args)
return reduce(lambda x, y : arith.AndI(x, y).result, SSAargs)

# Trigonometric functions
elif isinstance(node, sin):
assert len(node.args) == 1, "Expected single argument for sin."
Expand All @@ -298,13 +298,13 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr,
assert len(node.args) == 1, "Expected single argument for cos."
return math.CosOp(self._visit_math_nodes(dim, node.args[0],
output_indexed)).result

elif isinstance(node, tan):
assert len(node.args) == 1, "Expected single argument for TanOp."

return math.TanOp(self._visit_math_nodes(dim, node.args[0],
output_indexed)).result

elif isinstance(node, Relational):
if isinstance(node, GreaterThan):
mnemonic = "sge"
Expand Down Expand Up @@ -382,7 +382,20 @@ def build_stencil_step(self, dim: SteppingDimension, eq: LoweredEq) -> None:
self.function_values |= self.apply_temps

with ImplicitBuilder(apply.region.block):
stencil.ReturnOp.get([self._visit_math_nodes(dim, eq.rhs, eq.lhs)])
result = self._visit_math_nodes(dim, eq.rhs, eq.lhs)
expected_type = apply.res[0].type.get_element_type()
match expected_type:
case result.type:
pass
case builtin.f32:
if result.type == IndexType():
result = arith.IndexCastOp(result, builtin.i64).result
result = arith.SIToFPOp(result, builtin.f32).result
case builtin.IndexType:
result = arith.IndexCastOp(result, IndexType()).result
case _:
raise Exception(f"Unexpected result type {type(result)}")
stencil.ReturnOp.get([result])

lb = stencil.IndexAttr.get(*([0] * len(shape)))
ub = stencil.IndexAttr.get(*shape)
Expand Down Expand Up @@ -439,7 +452,6 @@ def build_condition(self, dim: SteppingDimension, eq: BooleanFunction):
self.build_generic_step_expression(dim, eq)
scf.Yield()


def build_time_loop(
self, eqs: list[Any], step_dim: SteppingDimension, **kwargs
):
Expand All @@ -450,7 +462,7 @@ def build_time_loop(
ub = iet_ssa.LoadSymbolic.get(
step_dim.symbolic_max._C_name, IndexType()
)

one = arith.Constant.from_int_and_width(1, IndexType())

# Devito iterates from time_m to time_M *inclusive*, MLIR only takes
Expand Down Expand Up @@ -497,7 +509,7 @@ def build_time_loop(
for i, (f, t) in enumerate(self.time_buffers)
}
self.function_values |= self.block_args

# Name the block argument for debugging
for (f, t), arg in self.block_args.items():
arg.name_hint = f"{f.name}_t{t}"
Expand All @@ -513,8 +525,7 @@ def build_time_loop(

def lower_devito_Eqs(self, eqs: list[Any], **kwargs):
# Lower devito Equations to xDSL



for eq in eqs:
lowered = self.operator._lower_exprs(as_tuple(eq), **kwargs)
if isinstance(eq, Eq):
Expand Down Expand Up @@ -546,7 +557,7 @@ def _lower_injection(self, eqs: list[LoweredEq]):
lb = arith.Constant.from_int_and_width(int(lower), IndexType())
else:
raise NotImplementedError(f"Lower bound of type {type(lower)} not supported")

try:
name = interval.dim.symbolic_min.name
except:
Expand Down Expand Up @@ -633,7 +644,7 @@ def convert(self, eqs: Iterable[Eq], **kwargs) -> ModuleOp:
# Instantiate the module.
self.function_values: dict[tuple[Function, int], SSAValue] = {}
self.symbol_values: dict[str, SSAValue] = {}

module = ModuleOp(Region([block := Block([])]))
with ImplicitBuilder(block):
# Get all functions used in the equations
Expand All @@ -647,7 +658,7 @@ def convert(self, eqs: Iterable[Eq], **kwargs) -> ModuleOp:
functions.add(f.function)

elif isinstance(eq, Injection):

functions.add(eq.field.function)
for f in retrieve_functions(eq.expr):
if isinstance(f, PointSource):
Expand Down
49 changes: 23 additions & 26 deletions tests/test_xdsl_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,20 @@ def test_function_IV():
assert np.isclose(norm(u), devito_norm_u)


def test_function_V():
grid = Grid(shape=(5, 5))
x, y = grid.dimensions

f = Function(name="f", grid=grid)

eqns = [Eq(f, 2)]

op = Operator(eqns, opt="xdsl")
op.apply()

assert np.all(f.data == 2)


class TestTrigonometric(object):

@pytest.mark.parametrize('deg, exp', ([90.0, 3.5759869], [30.0, 3.9521265],
Expand Down Expand Up @@ -1028,37 +1042,20 @@ def test_tan(self, deg, exp):
assert np.isclose(norm(u), exp, rtol=1e-4)


class TestOperatorUnsupported(object):
def test_forward_assignment():
# simple forward assignment

@pytest.mark.xfail(reason="stencil.return operation does not verify for i64")
def test_forward_assignment(self):
# simple forward assignment

grid = Grid(shape=(4, 4))
u = TimeFunction(name="u", grid=grid, space_order=2)
u.data[:, :, :] = 0

eq0 = Eq(u.forward, 1)

op = Operator([eq0], opt='xdsl')

op.apply(time_M=1)

assert np.isclose(norm(u), 5.6584, rtol=0.001)

@pytest.mark.xfail(reason="stencil.return operation does not verify for i64")
def test_function(self):
grid = Grid(shape=(5, 5))
x, y = grid.dimensions
grid = Grid(shape=(4, 4))
u = TimeFunction(name="u", grid=grid, space_order=2)
u.data[:, :, :] = 0

f = Function(name="f", grid=grid)
eq0 = Eq(u.forward, 1)

eqns = [Eq(f, 2)]
op = Operator([eq0], opt='xdsl')

op = Operator(eqns, opt='xdsl')
op.apply()
op.apply(time_M=1)

assert np.all(f.data == 4)
assert np.isclose(norm(u), 5.6584, rtol=0.001)


class TestElastic():
Expand Down

0 comments on commit 5c31f3c

Please sign in to comment.