diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 0803a15b..7ed60026 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -3,15 +3,17 @@ on:
   push:
     branches:
       - main
-permissions:
-  contents: write
+  pull_request:
+    branches:
+      - main
+
 jobs:
-  deploy:
+  lint:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v3
-      - uses: actions/setup-python@v4
+      - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # ratchet:actions/checkout@v4
+      - name: Set up Python 3.10
+        uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5
         with:
-          python-version: 3.x
-      - run: pip install -r docs/requirements.txt
-      - run: mkdocs gh-deploy --force
+          python-version: '3.10'
+      - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # ratchet:pre-commit/action@v3.0.1
diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
new file mode 100644
index 00000000..2ceaf2dd
--- /dev/null
+++ b/.github/workflows/docs.yml
@@ -0,0 +1,20 @@
+name: docs
+on:
+  push:
+    branches:
+      - main
+
+permissions:
+  contents: write
+
+jobs:
+  deploy:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # ratchet:actions/checkout@v4
+      - name: Set up Python 3.10
+        uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5
+        with:
+          python-version: '3.10'
+      - run: pip install -r docs/requirements.txt
+      - run: mkdocs gh-deploy --force
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index e69de29b..ab6bac1e 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,31 @@
+# Install the pre-commit hooks below with
+# 'pre-commit install'
+
+# Auto-update the version of the hooks with
+# 'pre-commit autoupdate'
+
+# Run the hooks on all files with
+# 'pre-commit run --all-files'
+
+repos:
+- repo: https://github.com/pre-commit/pre-commit-hooks
+  rev: 2c9f875913ee60ca25ce70243dc24d5b6415598c  # frozen: v4.6.0
+  hooks:
+  - id: check-ast
+  - id: check-merge-conflict
+  - id: check-toml
+  - id: check-yaml
+  - id: end-of-file-fixer
+    # only include python files
+    files: \.py$
+  - id: debug-statements
+    # only include python files
+    files: \.py$
+  - id: trailing-whitespace
+    # only include python files
+    files: \.py$
+
+- repo: https://github.com/astral-sh/ruff-pre-commit
+  rev: 8b5112a3b2ad121439a2092f8ff548c0d80f2514  # frozen: v0.6.1
+  hooks:
+  - id: ruff
diff --git a/examples/block_map.py b/examples/block_map.py
index 6ad922ad..8877a1d1 100644
--- a/examples/block_map.py
+++ b/examples/block_map.py
@@ -18,7 +18,6 @@
 
 import jax
 import jax.numpy as jnp
-from jax import lax
 from jax import random
 import jax_triton as jt
 from jax_triton import pallas as pl
@@ -178,8 +177,8 @@ def main(unused_argv):
   k = random.normal(k_key, shape, dtype=dtype)
   v = random.normal(v_key, shape, dtype=dtype)
-  o = mha(q, k, v).block_until_ready()
-  o_ref = mha_reference(q, k, v).block_until_ready()
+  mha(q, k, v).block_until_ready()
+  mha_reference(q, k, v).block_until_ready()
 
 
 if __name__ == "__main__":
   from absl import app
diff --git a/examples/pallas/blocksparse_matmul.py b/examples/pallas/blocksparse_matmul.py
index 67abd1e7..a7f08239 100644
--- a/examples/pallas/blocksparse_matmul.py
+++ b/examples/pallas/blocksparse_matmul.py
@@ -21,7 +21,6 @@
 from jax import random
 import jax
 from jax import lax
-import jax.numpy as jnp
 import numpy as np
 
 import jax_triton as jt
@@ -87,7 +86,7 @@ def tree_unflatten(cls, data, xs):
     return BlockELL(blocks, blocks_per_row, indices, shape=shape)
 
   def _validate(self):
-    nblocks, n, m = self.blocks.shape
+    _nblocks, n, m = self.blocks.shape
     nrows = self.blocks_per_row.shape[0]
     assert self.indices.shape[0] == nrows
     assert len(self.shape) == 2
@@ -168,7 +167,7 @@ def sdd_matmul(x_ell, y, num_warps: int = 8, num_stages: int = 3, bn: int = 64,
 
   grid = (jt.cdiv(m, bm), jt.cdiv(n, bn))
   kernel = functools.partial(sdd_kernel, bm=bm, bn=bn)
-  out_shape = jax.ShapeDtypeStruct(shape=(m, n), dtype=x.dtype)
+  out_shape = jax.ShapeDtypeStruct(shape=(m, n), dtype=x_ell.dtype)
   return pl.pallas_call(kernel, num_warps=num_warps, num_stages=num_stages,
                         grid=grid, out_shape=out_shape,
                         debug=debug)(x_ell.blocks, x_ell.indices,
diff --git a/examples/pallas/lstm.py b/examples/pallas/lstm.py
index 814f34fd..7eae81a5 100644
--- a/examples/pallas/lstm.py
+++ b/examples/pallas/lstm.py
@@ -17,14 +17,11 @@
 
 import functools
 import timeit
-from typing import Optional, Tuple
 
 import jax.numpy as jnp
 from jax import random
 import jax
-from jax import lax
 from jax._src.lax.control_flow import for_loop
-import jax.numpy as jnp
 import numpy as np
 
 import jax_triton as jt
@@ -188,13 +185,15 @@ def main(unused_argv):
   x = random.normal(x_key, (batch_size, feature_size), dtype)
   h = random.normal(h_key, (batch_size, hidden_size), dtype)
   c = random.normal(c_key, (batch_size, hidden_size), dtype)
-  lstm_cell = jax.jit(functools.partial(lstm_cell,
-                                        block_batch=block_batch,
-                                        block_hidden=block_hidden,
-                                        block_features=block_features,
-                                        num_warps=num_warps,
-                                        num_stages=num_stages))
-  y, c_next = jax.block_until_ready(lstm_cell(weights, x, h, c))
+  lstm_cell_fn = jax.jit(functools.partial(
+      lstm_cell,
+      block_batch=block_batch,
+      block_hidden=block_hidden,
+      block_features=block_features,
+      num_warps=num_warps,
+      num_stages=num_stages,
+  ))
+  y, c_next = jax.block_until_ready(lstm_cell_fn(weights, x, h, c))
   y_ref, c_next_ref = lstm_cell_reference(weights, x, h, c)
   np.testing.assert_allclose(y, y_ref, atol=0.05, rtol=0.05)
   np.testing.assert_allclose(c_next, c_next_ref, atol=0.05, rtol=0.05)
diff --git a/jax_triton/__init__.py b/jax_triton/__init__.py
index a923c880..bfba86a0 100644
--- a/jax_triton/__init__.py
+++ b/jax_triton/__init__.py
@@ -13,6 +13,17 @@
 # limitations under the License.
"""Library for JAX-Triton integrations.""" + +__all__ = [ + "utils", + "triton_call", + "cdiv", + "next_power_of_2", + "strides_from_shape", + "__version__", + "__version_info__", +] + import jaxlib from jax._src.lib import gpu_triton from jax_triton import utils diff --git a/jax_triton/experimental/fusion/__init__.py b/jax_triton/experimental/fusion/__init__.py index be06f43c..cfbe03d2 100644 --- a/jax_triton/experimental/fusion/__init__.py +++ b/jax_triton/experimental/fusion/__init__.py @@ -19,4 +19,4 @@ jax.nn.sigmoid = sigmoid del sigmoid, oryx, jax -from jax_triton.experimental.fusion.lowering import jit +from jax_triton.experimental.fusion.lowering import jit as jit diff --git a/jax_triton/experimental/fusion/fusion.py b/jax_triton/experimental/fusion/fusion.py index 6d9a0ca5..c6e73e1b 100644 --- a/jax_triton/experimental/fusion/fusion.py +++ b/jax_triton/experimental/fusion/fusion.py @@ -23,7 +23,6 @@ from jax import lax from jax.extend import linear_util as lu from jax.interpreters import partial_eval as pe -from jax.interpreters import xla from jax._src import core from jax._src import util from jax._src.lax.control_flow import for_loop @@ -251,6 +250,7 @@ def _matmul_elementwise_lowering_rule(x, y, *args, left_ops, right_ops, out_ops, bias, = args else: bias = None + del bias # TODO(sharadmv): Please fix or remove `bias` above. lhs_dim, rhs_dim = contract_dims M, N, K = x.shape[1 - lhs_dim], y.shape[1 - rhs_dim], x.shape[lhs_dim] assert x.shape[lhs_dim] == y.shape[rhs_dim] @@ -340,4 +340,3 @@ def _dot_general_lowering_rule(x, y, dimension_numbers, **_): out_ops=[], contract_dims=(lhs_dim, rhs_dim)) lowering_rules[lax.dot_general_p] = _dot_general_lowering_rule - diff --git a/jax_triton/experimental/fusion/jaxpr_rewriter.py b/jax_triton/experimental/fusion/jaxpr_rewriter.py index 8c0ea272..aa0277b4 100644 --- a/jax_triton/experimental/fusion/jaxpr_rewriter.py +++ b/jax_triton/experimental/fusion/jaxpr_rewriter.py @@ -19,7 +19,7 @@ import dataclasses import itertools as it -from typing import Any, Callable, Dict, List, Set, Tuple, Union +from typing import Any, Callable, List, Tuple, Union from jax._src import core as jax_core import jax.numpy as jnp diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index c20bcaa6..dc542e73 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -336,8 +336,6 @@ def compile_ttir_to_hsaco_inplace( amdgcn = hip_backend.make_amdgcn(llir, metadata, hip_options) hsaco = hip_backend.make_hsaco(amdgcn, metadata, hip_options) - if hip_options.debug: - print(x) name = metadata["name"] ttgir = str(ttgir) if _JAX_TRITON_DUMP_DIR else None llir = str(llir) if _JAX_TRITON_DUMP_DIR else None diff --git a/jax_triton/utils.py b/jax_triton/utils.py index 24206028..ee8a6277 100644 --- a/jax_triton/utils.py +++ b/jax_triton/utils.py @@ -13,6 +13,11 @@ # limitations under the License. 
"""Contains utilities for writing and calling Triton functions.""" + + +__all__ = ["cdiv", "strides_from_shape", "next_power_of_2"] + + from jax.experimental.pallas import cdiv from jax.experimental.pallas import strides_from_shape -from jax.experimental.pallas import next_power_of_2 \ No newline at end of file +from jax.experimental.pallas import next_power_of_2 diff --git a/pyproject.toml b/pyproject.toml index 0b5f6ca2..59c280d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,3 +26,25 @@ packages = ["jax_triton"] [tool.setuptools.dynamic] version = {attr = "jax_triton.version.__version__"} + +[tool.ruff] +preview = true +exclude = [ + ".git", + "build", + "__pycache__", + "*.ipynb", +] +line-length = 88 +indent-width = 2 +target-version = "py310" + +[tool.ruff.lint] +ignore = [ + # Do not assign a `lambda` expression, use a `def` + "E731", + # Module level import not at top of file + "E402", + # Ambiguous variable name + "E741", +]