diff --git a/src/integration_tests/algorithmic_style_test.py b/src/integration_tests/algorithmic_style_test.py index 28a049b..9e493b0 100644 --- a/src/integration_tests/algorithmic_style_test.py +++ b/src/integration_tests/algorithmic_style_test.py @@ -5,7 +5,7 @@ import textwrap from typing import Any, Callable -from latexify import frontend +from latexify import generate_latex def check_algorithm( @@ -24,7 +24,9 @@ def check_algorithm( # def fn(...): # ... # latexified = get_latex(fn, style=ALGORITHM, **kwargs) - latexified = frontend.get_latex(fn, style=frontend.Style.ALGORITHMIC, **kwargs) + latexified = generate_latex.get_latex( + fn, style=generate_latex.Style.ALGORITHMIC, **kwargs + ) assert latexified == latex diff --git a/src/latexify/__init__.py b/src/latexify/__init__.py index f25f699..b64ca04 100644 --- a/src/latexify/__init__.py +++ b/src/latexify/__init__.py @@ -7,11 +7,11 @@ except Exception: __version__ = "" -from latexify import frontend +from latexify import frontend, generate_latex -Style = frontend.Style +Style = generate_latex.Style -get_latex = frontend.get_latex +get_latex = generate_latex.get_latex function = frontend.function expression = frontend.expression diff --git a/src/latexify/frontend.py b/src/latexify/frontend.py index 3dab383..440b8f2 100644 --- a/src/latexify/frontend.py +++ b/src/latexify/frontend.py @@ -2,145 +2,74 @@ from __future__ import annotations -import enum from collections.abc import Callable from typing import Any, overload -from latexify import codegen -from latexify import config as cfg -from latexify import exceptions, parser, transformers +from latexify import ipython_wrappers -class Style(enum.Enum): - EXPRESSION = "expression" - FUNCTION = "function" - ALGORITHMIC = "algorithmic" +@overload +def algorithmic( + fn: Callable[..., Any], **kwargs: Any +) -> ipython_wrappers.LatexifiedAlgorithm: + ... + + +@overload +def algorithmic( + **kwargs: Any, +) -> Callable[[Callable[..., Any]], ipython_wrappers.LatexifiedAlgorithm]: + ... -def get_latex( - fn: Callable[..., Any], - *, - style: Style = Style.FUNCTION, - config: cfg.Config | None = None, - **kwargs, -) -> str: - """Obtains LaTeX description from the function's source. +def algorithmic( + fn: Callable[..., Any] | None = None, **kwargs: Any +) -> ipython_wrappers.LatexifiedAlgorithm | Callable[ + [Callable[..., Any]], ipython_wrappers.LatexifiedAlgorithm +]: + """Attach LaTeX pretty-printing to the given function. + + This function works with or without specifying the target function as the + positional argument. The following two syntaxes works similarly. + - latexify.algorithmic(alg, **kwargs) + - latexify.algorithmic(**kwargs)(alg) Args: - fn: Reference to a function to analyze. - style: Style of the LaTeX description, the default is FUNCTION. - config: Use defined Config object, if it is None, it will be automatic assigned - with default value. - **kwargs: Dict of Config field values that could be defined individually - by users. + fn: Callable to be wrapped. + **kwargs: Arguments to control behavior. See also get_latex(). Returns: - Generated LaTeX description. - - Raises: - latexify.exceptions.LatexifyError: Something went wrong during conversion. + - If `fn` is passed, returns the wrapped function. + - Otherwise, returns the wrapper function with given settings. """ - if style == Style.EXPRESSION: - kwargs["use_signature"] = kwargs.get("use_signature", False) - - merged_config = cfg.Config.defaults().merge(config=config, **kwargs) - - # Obtains the source AST. - tree = parser.parse_function(fn) - - # Applies AST transformations. - - if merged_config.prefixes is not None: - tree = transformers.PrefixTrimmer(merged_config.prefixes).visit(tree) - if merged_config.identifiers is not None: - tree = transformers.IdentifierReplacer(merged_config.identifiers).visit(tree) - if merged_config.reduce_assignments: - tree = transformers.AssignmentReducer().visit(tree) - if merged_config.expand_functions is not None: - tree = transformers.FunctionExpander(merged_config.expand_functions).visit(tree) - - # Generates LaTeX. - if style == Style.ALGORITHMIC: - return codegen.AlgorithmicCodegen( - use_math_symbols=merged_config.use_math_symbols, - use_set_symbols=merged_config.use_set_symbols, - ).visit(tree) - else: - return codegen.FunctionCodegen( - use_math_symbols=merged_config.use_math_symbols, - use_signature=merged_config.use_signature, - use_set_symbols=merged_config.use_set_symbols, - ).visit(tree) - - -class LatexifiedFunction: - """Function with latex representation.""" - - _fn: Callable[..., Any] - _latex: str | None - _error: str | None - - def __init__(self, fn, **kwargs): - self._fn = fn - try: - self._latex = get_latex(fn, **kwargs) - self._error = None - except exceptions.LatexifyError as e: - self._latex = None - self._error = f"{type(e).__name__}: {str(e)}" - - @property - def __doc__(self): - return self._fn.__doc__ - - @__doc__.setter - def __doc__(self, val): - self._fn.__doc__ = val - - @property - def __name__(self): - return self._fn.__name__ - - @__name__.setter - def __name__(self, val): - self._fn.__name__ = val - - def __call__(self, *args): - return self._fn(*args) - - def __str__(self): - return self._latex if self._latex is not None else self._error - - def _repr_html_(self): - """IPython hook to display HTML visualization.""" - return ( - '' + self._error + "" - if self._error is not None - else None - ) - - def _repr_latex_(self): - """IPython hook to display LaTeX visualization.""" - return ( - r"$$ \displaystyle " + self._latex + " $$" - if self._latex is not None - else self._error - ) + if fn is not None: + return ipython_wrappers.LatexifiedAlgorithm(fn, **kwargs) + + def wrapper(f): + return ipython_wrappers.LatexifiedAlgorithm(f, **kwargs) + + return wrapper @overload -def function(fn: Callable[..., Any], **kwargs: Any) -> LatexifiedFunction: +def function( + fn: Callable[..., Any], **kwargs: Any +) -> ipython_wrappers.LatexifiedFunction: ... @overload -def function(**kwargs: Any) -> Callable[[Callable[..., Any]], LatexifiedFunction]: +def function( + **kwargs: Any, +) -> Callable[[Callable[..., Any]], ipython_wrappers.LatexifiedFunction]: ... def function( fn: Callable[..., Any] | None = None, **kwargs: Any -) -> LatexifiedFunction | Callable[[Callable[..., Any]], LatexifiedFunction]: +) -> ipython_wrappers.LatexifiedFunction | Callable[ + [Callable[..., Any]], ipython_wrappers.LatexifiedFunction +]: """Attach LaTeX pretty-printing to the given function. This function works with or without specifying the target function as the positional @@ -157,33 +86,39 @@ def function( - Otherwise, returns the wrapper function with given settings. """ if fn is not None: - return LatexifiedFunction(fn, **kwargs) + return ipython_wrappers.LatexifiedFunction(fn, **kwargs) def wrapper(f): - return LatexifiedFunction(f, **kwargs) + return ipython_wrappers.LatexifiedFunction(f, **kwargs) return wrapper @overload -def expression(fn: Callable[..., Any], **kwargs: Any) -> LatexifiedFunction: +def expression( + fn: Callable[..., Any], **kwargs: Any +) -> ipython_wrappers.LatexifiedFunction: ... @overload -def expression(**kwargs: Any) -> Callable[[Callable[..., Any]], LatexifiedFunction]: +def expression( + **kwargs: Any, +) -> Callable[[Callable[..., Any]], ipython_wrappers.LatexifiedFunction]: ... def expression( fn: Callable[..., Any] | None = None, **kwargs: Any -) -> LatexifiedFunction | Callable[[Callable[..., Any]], LatexifiedFunction]: +) -> ipython_wrappers.LatexifiedFunction | Callable[ + [Callable[..., Any]], ipython_wrappers.LatexifiedFunction +]: """Attach LaTeX pretty-printing to the given function. This function is a shortcut for `latexify.function` with the default parameter `use_signature=False`. """ - kwargs["style"] = Style.EXPRESSION + kwargs["use_signature"] = kwargs.get("use_signature", False) if fn is not None: return function(fn, **kwargs) else: diff --git a/src/latexify/frontend_test.py b/src/latexify/frontend_test.py index 31e2bc3..e8425da 100644 --- a/src/latexify/frontend_test.py +++ b/src/latexify/frontend_test.py @@ -5,89 +5,6 @@ from latexify import frontend -def test_get_latex_identifiers() -> None: - def myfn(myvar): - return 3 * myvar - - identifiers = {"myfn": "f", "myvar": "x"} - - latex_without_flag = r"\mathrm{myfn}(\mathrm{myvar}) = 3 \cdot \mathrm{myvar}" - latex_with_flag = r"f(x) = 3 \cdot x" - - assert frontend.get_latex(myfn) == latex_without_flag - assert frontend.get_latex(myfn, identifiers=identifiers) == latex_with_flag - - -def test_get_latex_prefixes() -> None: - abc = object() - - def f(x): - return abc.d + x.y.z.e - - latex_without_flag = r"f(x) = \mathrm{abc}.d + x.y.z.e" - latex_with_flag1 = r"f(x) = d + x.y.z.e" - latex_with_flag2 = r"f(x) = \mathrm{abc}.d + y.z.e" - latex_with_flag3 = r"f(x) = \mathrm{abc}.d + z.e" - latex_with_flag4 = r"f(x) = d + e" - - assert frontend.get_latex(f) == latex_without_flag - assert frontend.get_latex(f, prefixes=set()) == latex_without_flag - assert frontend.get_latex(f, prefixes={"abc"}) == latex_with_flag1 - assert frontend.get_latex(f, prefixes={"x"}) == latex_with_flag2 - assert frontend.get_latex(f, prefixes={"x.y"}) == latex_with_flag3 - assert frontend.get_latex(f, prefixes={"abc", "x.y.z"}) == latex_with_flag4 - assert frontend.get_latex(f, prefixes={"abc", "x", "x.y.z"}) == latex_with_flag4 - - -def test_get_latex_reduce_assignments() -> None: - def f(x): - y = 3 * x - return y - - latex_without_flag = r"\begin{array}{l} y = 3 \cdot x \\ f(x) = y \end{array}" - latex_with_flag = r"f(x) = 3 \cdot x" - - assert frontend.get_latex(f) == latex_without_flag - assert frontend.get_latex(f, reduce_assignments=False) == latex_without_flag - assert frontend.get_latex(f, reduce_assignments=True) == latex_with_flag - - -def test_get_latex_use_math_symbols() -> None: - def f(alpha): - return alpha - - latex_without_flag = r"f(\mathrm{alpha}) = \mathrm{alpha}" - latex_with_flag = r"f(\alpha) = \alpha" - - assert frontend.get_latex(f) == latex_without_flag - assert frontend.get_latex(f, use_math_symbols=False) == latex_without_flag - assert frontend.get_latex(f, use_math_symbols=True) == latex_with_flag - - -def test_get_latex_use_signature() -> None: - def f(x): - return x - - latex_without_flag = "x" - latex_with_flag = r"f(x) = x" - - assert frontend.get_latex(f) == latex_with_flag - assert frontend.get_latex(f, use_signature=False) == latex_without_flag - assert frontend.get_latex(f, use_signature=True) == latex_with_flag - - -def test_get_latex_use_set_symbols() -> None: - def f(x, y): - return x & y - - latex_without_flag = r"f(x, y) = x \mathbin{\&} y" - latex_with_flag = r"f(x, y) = x \cap y" - - assert frontend.get_latex(f) == latex_without_flag - assert frontend.get_latex(f, use_set_symbols=False) == latex_without_flag - assert frontend.get_latex(f, use_set_symbols=True) == latex_with_flag - - def test_function() -> None: def f(x): return x diff --git a/src/latexify/generate_latex.py b/src/latexify/generate_latex.py new file mode 100644 index 0000000..252bebc --- /dev/null +++ b/src/latexify/generate_latex.py @@ -0,0 +1,77 @@ +"""Generate LaTeX code.""" + +from __future__ import annotations + +import enum +from collections.abc import Callable +from typing import Any + +from latexify import codegen +from latexify import config as cfg +from latexify import exceptions, parser, transformers + + +class Style(enum.Enum): + """The style of the generated LaTeX.""" + + ALGORITHMIC = "algorithmic" + EXPRESSION = "expression" + FUNCTION = "function" + IPYTHON_ALGORITHMIC = "ipython-algorithmic" + + +def get_latex( + fn: Callable[..., Any], + *, + style: Style = Style.FUNCTION, + config: cfg.Config | None = None, + **kwargs, +) -> str: + """Obtains LaTeX description from the function's source. + + Args: + fn: Reference to a function to analyze. + style: Style of the LaTeX description, the default is FUNCTION. + config: Use defined Config object, if it is None, it will be automatic assigned + with default value. + **kwargs: Dict of Config field values that could be defined individually + by users. + + Returns: + Generated LaTeX description. + + Raises: + latexify.exceptions.LatexifyError: Something went wrong during conversion. + """ + merged_config = cfg.Config.defaults().merge(config=config, **kwargs) + + # Obtains the source AST. + tree = parser.parse_function(fn) + + # Applies AST transformations. + if merged_config.prefixes is not None: + tree = transformers.PrefixTrimmer(merged_config.prefixes).visit(tree) + if merged_config.identifiers is not None: + tree = transformers.IdentifierReplacer(merged_config.identifiers).visit(tree) + if merged_config.reduce_assignments: + tree = transformers.AssignmentReducer().visit(tree) + if merged_config.expand_functions is not None: + tree = transformers.FunctionExpander(merged_config.expand_functions).visit(tree) + + # Generates LaTeX. + if style == Style.ALGORITHMIC: + return codegen.AlgorithmicCodegen( + use_math_symbols=merged_config.use_math_symbols, + use_set_symbols=merged_config.use_set_symbols, + ).visit(tree) + elif style == Style.IPYTHON_ALGORITHMIC: + # TODO(ZibingZhang): implement algorithmic codegen for ipython + raise exceptions.LatexifyNotSupportedError + else: + if style == Style.EXPRESSION: + kwargs["use_signature"] = kwargs.get("use_signature", False) + return codegen.FunctionCodegen( + use_math_symbols=merged_config.use_math_symbols, + use_signature=merged_config.use_signature, + use_set_symbols=merged_config.use_set_symbols, + ).visit(tree) diff --git a/src/latexify/generate_latex_test.py b/src/latexify/generate_latex_test.py new file mode 100644 index 0000000..e9df79e --- /dev/null +++ b/src/latexify/generate_latex_test.py @@ -0,0 +1,90 @@ +"""Tests for latexify.generate_latex.""" + +from __future__ import annotations + +from latexify import generate_latex + + +def test_get_latex_identifiers() -> None: + def myfn(myvar): + return 3 * myvar + + identifiers = {"myfn": "f", "myvar": "x"} + + latex_without_flag = r"\mathrm{myfn}(\mathrm{myvar}) = 3 \cdot \mathrm{myvar}" + latex_with_flag = r"f(x) = 3 \cdot x" + + assert generate_latex.get_latex(myfn) == latex_without_flag + assert generate_latex.get_latex(myfn, identifiers=identifiers) == latex_with_flag + + +def test_get_latex_prefixes() -> None: + abc = object() + + def f(x): + return abc.d + x.y.z.e + + latex_without_flag = r"f(x) = \mathrm{abc}.d + x.y.z.e" + latex_with_flag1 = r"f(x) = d + x.y.z.e" + latex_with_flag2 = r"f(x) = \mathrm{abc}.d + y.z.e" + latex_with_flag3 = r"f(x) = \mathrm{abc}.d + z.e" + latex_with_flag4 = r"f(x) = d + e" + + assert generate_latex.get_latex(f) == latex_without_flag + assert generate_latex.get_latex(f, prefixes=set()) == latex_without_flag + assert generate_latex.get_latex(f, prefixes={"abc"}) == latex_with_flag1 + assert generate_latex.get_latex(f, prefixes={"x"}) == latex_with_flag2 + assert generate_latex.get_latex(f, prefixes={"x.y"}) == latex_with_flag3 + assert generate_latex.get_latex(f, prefixes={"abc", "x.y.z"}) == latex_with_flag4 + assert ( + generate_latex.get_latex(f, prefixes={"abc", "x", "x.y.z"}) == latex_with_flag4 + ) + + +def test_get_latex_reduce_assignments() -> None: + def f(x): + y = 3 * x + return y + + latex_without_flag = r"\begin{array}{l} y = 3 \cdot x \\ f(x) = y \end{array}" + latex_with_flag = r"f(x) = 3 \cdot x" + + assert generate_latex.get_latex(f) == latex_without_flag + assert generate_latex.get_latex(f, reduce_assignments=False) == latex_without_flag + assert generate_latex.get_latex(f, reduce_assignments=True) == latex_with_flag + + +def test_get_latex_use_math_symbols() -> None: + def f(alpha): + return alpha + + latex_without_flag = r"f(\mathrm{alpha}) = \mathrm{alpha}" + latex_with_flag = r"f(\alpha) = \alpha" + + assert generate_latex.get_latex(f) == latex_without_flag + assert generate_latex.get_latex(f, use_math_symbols=False) == latex_without_flag + assert generate_latex.get_latex(f, use_math_symbols=True) == latex_with_flag + + +def test_get_latex_use_signature() -> None: + def f(x): + return x + + latex_without_flag = "x" + latex_with_flag = r"f(x) = x" + + assert generate_latex.get_latex(f) == latex_with_flag + assert generate_latex.get_latex(f, use_signature=False) == latex_without_flag + assert generate_latex.get_latex(f, use_signature=True) == latex_with_flag + + +def test_get_latex_use_set_symbols() -> None: + def f(x, y): + return x & y + + latex_without_flag = r"f(x, y) = x \mathbin{\&} y" + latex_with_flag = r"f(x, y) = x \cap y" + + assert generate_latex.get_latex(f) == latex_without_flag + assert generate_latex.get_latex(f, use_set_symbols=False) == latex_without_flag + assert generate_latex.get_latex(f, use_set_symbols=True) == latex_with_flag diff --git a/src/latexify/ipython_wrappers.py b/src/latexify/ipython_wrappers.py new file mode 100644 index 0000000..88aa015 --- /dev/null +++ b/src/latexify/ipython_wrappers.py @@ -0,0 +1,139 @@ +"""Wrapper objects for IPython to display output.""" + +from __future__ import annotations + +import abc +from typing import Any, Callable, cast + +from latexify import exceptions, generate_latex + + +class LatexifiedRepr(metaclass=abc.ABCMeta): + """Object with LaTeX representation.""" + + _fn: Callable[..., Any] + + def __init__(self, fn: Callable[..., Any], **kwargs) -> None: + self._fn = fn + + @property + def __doc__(self) -> str | None: + return self._fn.__doc__ + + @__doc__.setter + def __doc__(self, val: str | None) -> None: + self._fn.__doc__ = val + + @property + def __name__(self) -> str: + return self._fn.__name__ + + @__name__.setter + def __name__(self, val: str) -> None: + self._fn.__name__ = val + + # After Python 3.7 + # @final + def __call__(self, *args) -> Any: + return self._fn(*args) + + @abc.abstractmethod + def __str__(self) -> str: + ... + + @abc.abstractmethod + def _repr_html_(self) -> str | tuple[str, dict[str, Any]] | None: + """IPython hook to display HTML visualization.""" + ... + + @abc.abstractmethod + def _repr_latex_(self) -> str | tuple[str, dict[str, Any]] | None: + """IPython hook to display LaTeX visualization.""" + ... + + +class LatexifiedAlgorithm(LatexifiedRepr): + """Algorithm with latex representation.""" + + _latex: str | None + _error: str | None + _ipython_latex: str | None + _ipython_error: str | None + + def __init__(self, fn: Callable[..., Any], **kwargs) -> None: + super().__init__(fn) + + try: + self._latex = generate_latex.get_latex( + fn, style=generate_latex.Style.ALGORITHMIC, **kwargs + ) + self._error = None + except exceptions.LatexifyError as e: + self._latex = None + self._error = f"{type(e).__name__}: {str(e)}" + + try: + self._ipython_latex = generate_latex.get_latex( + fn, style=generate_latex.Style.IPYTHON_ALGORITHMIC, **kwargs + ) + self._ipython_error = None + except exceptions.LatexifyError as e: + self._ipython_latex = None + self._ipython_error = f"{type(e).__name__}: {str(e)}" + + def __str__(self) -> str: + return self._latex if self._latex is not None else cast(str, self._error) + + def _repr_html_(self) -> str | tuple[str, dict[str, Any]] | None: + """IPython hook to display HTML visualization.""" + return ( + '' + self._ipython_error + "" + if self._ipython_error is not None + else None + ) + + def _repr_latex_(self) -> str | tuple[str, dict[str, Any]] | None: + """IPython hook to display LaTeX visualization.""" + return ( + r"$ " + self._ipython_latex + " $" + if self._ipython_latex is not None + else self._ipython_error + ) + + +class LatexifiedFunction(LatexifiedRepr): + """Function with latex representation.""" + + _latex: str | None + _error: str | None + + def __init__(self, fn: Callable[..., Any], **kwargs) -> None: + super().__init__(fn, **kwargs) + + try: + self._latex = self._latex = generate_latex.get_latex( + fn, style=generate_latex.Style.FUNCTION, **kwargs + ) + self._error = None + except exceptions.LatexifyError as e: + self._latex = None + self._error = f"{type(e).__name__}: {str(e)}" + + def __str__(self) -> str: + return self._latex if self._latex is not None else cast(str, self._error) + + def _repr_html_(self) -> str | tuple[str, dict[str, Any]] | None: + """IPython hook to display HTML visualization.""" + return ( + '' + self._error + "" + if self._error is not None + else None + ) + + def _repr_latex_(self) -> str | tuple[str, dict[str, Any]] | None: + """IPython hook to display LaTeX visualization.""" + return ( + r"$$ \displaystyle " + self._latex + " $$" + if self._latex is not None + else self._error + )