Skip to content

Commit

Permalink
set operations (#94)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yusuke Oda authored Nov 14, 2022
1 parent 5dde4e4 commit 78a6d01
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 2 deletions.
32 changes: 30 additions & 2 deletions src/latexify/codegen/function_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,17 @@ class BinOpRule:
ast.BitOr: BinOpRule("", r" \mathbin{|} ", ""),
}

# Typeset for BinOp of sets.
_SET_BIN_OP_RULES: dict[type[ast.operator], BinOpRule] = {
**_BIN_OP_RULES,
ast.Sub: BinOpRule(
"", r" \setminus ", "", operand_right=BinOperandRule(force=True)
),
ast.BitAnd: BinOpRule("", r" \cap ", ""),
ast.BitXor: BinOpRule("", r" \mathbin{\triangle} ", ""),
ast.BitOr: BinOpRule("", r" \cup ", ""),
}

_UNARY_OPS: dict[type[ast.unaryop], str] = {
ast.Invert: r"\mathord{\sim} ",
ast.UAdd: "+", # Explicitly adds the $+$ operator.
Expand All @@ -164,6 +175,15 @@ class BinOpRule:
ast.NotIn: r"\notin",
}

# Typeset for Compare of sets.
_SET_COMPARE_OPS: dict[type[ast.cmpop], str] = {
**_COMPARE_OPS,
ast.Gt: r"\supset",
ast.GtE: r"\supseteq",
ast.Lt: r"\subset",
ast.LtE: r"\subseteq",
}

_BOOL_OPS: dict[type[ast.boolop], str] = {
ast.And: r"\land",
ast.Or: r"\lor",
Expand All @@ -181,12 +201,16 @@ class FunctionCodegen(ast.NodeVisitor):
_use_raw_function_name: bool
_use_signature: bool

_bin_op_rules: dict[type[ast.operator], BinOpRule]
_compare_ops: dict[type[ast.cmpop], str]

def __init__(
self,
*,
use_math_symbols: bool = False,
use_raw_function_name: bool = False,
use_signature: bool = True,
use_set_symbols: bool = False,
) -> None:
"""Initializer.
Expand All @@ -197,13 +221,17 @@ def __init__(
or convert it to subscript.
use_signature: Whether to add the function signature before the expression
or not.
use_set_symbols: Whether to use set symbols or not.
"""
self._math_symbol_converter = math_symbols.MathSymbolConverter(
enabled=use_math_symbols
)
self._use_raw_function_name = use_raw_function_name
self._use_signature = use_signature

self._bin_op_rules = _SET_BIN_OP_RULES if use_set_symbols else _BIN_OP_RULES
self._compare_ops = _SET_COMPARE_OPS if use_set_symbols else _COMPARE_OPS

def generic_visit(self, node: ast.AST) -> str:
raise exceptions.LatexifyNotSupportedError(
f"Unsupported AST: {type(node).__name__}"
Expand Down Expand Up @@ -445,7 +473,7 @@ def _wrap_binop_operand(
def visit_BinOp(self, node: ast.BinOp) -> str:
"""Visit a BinOp node."""
prec = _get_precedence(node)
rule = _BIN_OP_RULES[type(node.op)]
rule = self._bin_op_rules[type(node.op)]
lhs = self._wrap_binop_operand(node.left, prec, rule.operand_left)
rhs = self._wrap_binop_operand(node.right, prec, rule.operand_right)
return f"{rule.latex_left}{lhs}{rule.latex_middle}{rhs}{rule.latex_right}"
Expand All @@ -459,7 +487,7 @@ def visit_Compare(self, node: ast.Compare) -> str:
"""Visit a compare node."""
parent_prec = _get_precedence(node)
lhs = self._wrap_operand(node.left, parent_prec)
ops = [_COMPARE_OPS[type(x)] for x in node.ops]
ops = [self._compare_ops[type(x)] for x in node.ops]
rhs = [self._wrap_operand(x, parent_prec) for x in node.comparators]
ops_rhs = [f" {o} {r}" for o, r in zip(ops, rhs)]
return "{" + lhs + "".join(ops_rhs) + "}"
Expand Down
30 changes: 30 additions & 0 deletions src/latexify/codegen/function_codegen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,3 +563,33 @@ def test_visit_subscript(code: str, latex: str) -> None:
tree = ast.parse(code).body[0].value
assert isinstance(tree, ast.Subscript)
assert function_codegen.FunctionCodegen().visit(tree) == latex


@pytest.mark.parametrize(
"code,latex",
[
("a - b", r"a \setminus b"),
("a & b", r"a \cap b"),
("a ^ b", r"a \mathbin{\triangle} b"),
("a | b", r"a \cup b"),
],
)
def test_use_set_symbols_binop(code: str, latex: str) -> None:
tree = ast.parse(code).body[0].value
assert isinstance(tree, ast.BinOp)
assert function_codegen.FunctionCodegen(use_set_symbols=True).visit(tree) == latex


@pytest.mark.parametrize(
"code,latex",
[
("a < b", r"{a \subset b}"),
("a <= b", r"{a \subseteq b}"),
("a > b", r"{a \supset b}"),
("a >= b", r"{a \supseteq b}"),
],
)
def test_use_set_symbols_compare(code: str, latex: str) -> None:
tree = ast.parse(code).body[0].value
assert isinstance(tree, ast.Compare)
assert function_codegen.FunctionCodegen(use_set_symbols=True).visit(tree) == latex
3 changes: 3 additions & 0 deletions src/latexify/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def get_latex(
use_math_symbols: bool = False,
use_raw_function_name: bool = False,
use_signature: bool = True,
use_set_symbols: bool = False,
) -> str:
"""Obtains LaTeX description from the function's source.
Expand All @@ -38,6 +39,7 @@ def get_latex(
or convert it to subscript.
use_signature: Whether to add the function signature before the expression or
not.
use_set_symbols: Whether to use set symbols or not.
Returns:
Generatee LaTeX description.
Expand All @@ -59,6 +61,7 @@ def get_latex(
use_math_symbols=use_math_symbols,
use_raw_function_name=use_raw_function_name,
use_signature=use_signature,
use_set_symbols=use_set_symbols,
).visit(tree)


Expand Down
12 changes: 12 additions & 0 deletions src/latexify/frontend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,18 @@ def f(x):
assert frontend.get_latex(f, use_signature=True) == latex_with_flag


def test_get_latex_use_set_symbols() -> None:
def f(x, y):
return x & y

latex_without_flag = r"\mathrm{f}(x, y) = x \mathbin{\&} y"
latex_with_flag = r"\mathrm{f}(x, y) = x \cap y"

assert frontend.get_latex(f) == latex_without_flag
assert frontend.get_latex(f, use_set_symbols=False) == latex_without_flag
assert frontend.get_latex(f, use_set_symbols=True) == latex_with_flag


def test_function() -> None:
def f(x):
return x
Expand Down

0 comments on commit 78a6d01

Please sign in to comment.