Skip to content

Commit

Permalink
Merge pull request #112 from xdslproject/add_tan
Browse files Browse the repository at this point in the history
compiler: Add tan
  • Loading branch information
georgebisbas authored Jul 17, 2024
2 parents ca681a4 + 1bdf1f1 commit 0d2fd87
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 14 deletions.
15 changes: 9 additions & 6 deletions devito/ir/xdsl_iet/cluster_to_ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass, field
from sympy import (Add, And, Expr, Float, GreaterThan, Indexed, Integer, LessThan,
Number, Pow, StrictGreaterThan, StrictLessThan, Symbol, floor,
Mul, sin, cos)
Mul, sin, cos, tan)
from sympy.core.relational import Relational
from sympy.logic.boolalg import BooleanFunction
from devito.operations.interpolators import Injection
Expand Down Expand Up @@ -288,17 +288,22 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr,
for arg in node.args)
return reduce(lambda x, y : arith.AndI(x, y).result, SSAargs)

# Trigonometric functions
elif isinstance(node, sin):
assert len(node.args) == 1, "Expected single argument for sin."
return math.SinOp(self._visit_math_nodes(dim, node.args[0],
output_indexed)).result

elif isinstance(node, cos):
assert len(node.args) == 1, "Expected single argument for cos."

assert len(node.args) == 1, "Expected single argument for cos."
return math.CosOp(self._visit_math_nodes(dim, node.args[0],
output_indexed)).result

elif isinstance(node, tan):
assert len(node.args) == 1, "Expected single argument for TanOp."

return math.TanOp(self._visit_math_nodes(dim, node.args[0],
output_indexed)).result

elif isinstance(node, Relational):
if isinstance(node, GreaterThan):
Expand All @@ -311,9 +316,7 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr,
mnemonic = "slt"
else:
raise NotImplementedError(f"Unimplemented comparison {type(node)}")

# import pdb;
# pdb.set_trace()

SSAargs = (self._visit_math_nodes(dim, arg, output_indexed) for arg in node.args)
# Operands must have the same type
# TODO: look at here if index stuff does not make sense
Expand Down
43 changes: 35 additions & 8 deletions tests/test_xdsl_base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import numpy as np
import pytest

from devito import (Grid, TensorTimeFunction, VectorTimeFunction, div, grad, diag, solve,
Operator, Eq, Constant, norm, SpaceDimension, switchconfig, sin, cos)
from devito import (Grid, TensorTimeFunction, VectorTimeFunction, div, grad,
diag, solve, Operator, Eq, Constant, norm, SpaceDimension,
switchconfig, sin, cos, tan)
from devito.types import Array, Function, TimeFunction
from devito.tools import as_tuple

Expand All @@ -14,6 +15,7 @@
from xdsl.dialects.stencil import FieldType, ApplyOp, LoadOp, StoreOp
from xdsl.dialects.llvm import LLVMPointerType
from xdsl.dialects.memref import Load
from xdsl.dialects.experimental import math


def test_xdsl_I():
Expand Down Expand Up @@ -980,10 +982,13 @@ def test_sine(self, deg, exp):
u = Function(name="u", grid=grid)
u.data[:, :] = 0

eq0 = Eq(u, sin(deg))
deg0 = Constant(name='deg', value=deg)
eq0 = Eq(u, sin(deg0))

op = Operator([eq0], opt='xdsl')
op.apply()
opx = Operator([eq0], opt='xdsl')
opx.apply()

assert len([op for op in opx._module.walk() if isinstance(op, math.SinOp)]) == 1
assert np.isclose(norm(u), exp, rtol=1e-4)

@pytest.mark.parametrize('deg, exp', ([90.0, 1.7922944], [30.0, 0.6170056],
Expand All @@ -994,10 +999,32 @@ def test_cosine(self, deg, exp):
u = Function(name="u", grid=grid)
u.data[:, :] = 0

eq0 = Eq(u, cos(deg))
deg0 = Constant(name='deg', value=deg)
eq0 = Eq(u, cos(deg0))

opx = Operator([eq0], opt='xdsl')
opx.apply()

assert len([op for op in opx._module.walk() if isinstance(op, math.CosOp)]) == 1

assert np.isclose(norm(u), exp, rtol=1e-4)

@pytest.mark.parametrize('deg, exp', ([2.0, 8.74016], [30.0, 25.621325],
[45.0, 6.4791]))
def test_tan(self, deg, exp):
grid = Grid(shape=(4, 4))

u = Function(name="u", grid=grid)
u.data[:, :] = 0

deg0 = Constant(name='deg', value=deg)
eq0 = Eq(u, tan(deg0))

opx = Operator([eq0], opt='xdsl')
opx.apply()

assert len([op for op in opx._module.walk() if isinstance(op, math.TanOp)]) == 1

op = Operator([eq0], opt='xdsl')
op.apply()
assert np.isclose(norm(u), exp, rtol=1e-4)


Expand Down

0 comments on commit 0d2fd87

Please sign in to comment.