Skip to content

Commit

Permalink
TYP: adopt based{mypy, pyright} (#37)
Browse files Browse the repository at this point in the history
* TYP: adopt based{mypy, pyright}

* fix link in comments

* DEV: update lockfile

* address review comments

* update lockfile

* ignore missing xp-strict stubs

* update docs
  • Loading branch information
lucascolley authored Nov 26, 2024
1 parent c1d0e20 commit 1a33939
Show file tree
Hide file tree
Showing 9 changed files with 292 additions and 298 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@ jobs:
with:
pixi-version: v0.37.0
cache: true
- name: Run Pylint & Mypy
- name: Run Pylint, Mypy & Pyright
run: |
pixi run -e lint pylint
pixi run -e lint mypy
pixi run -e lint pyright
checks:
name: Check ${{ matrix.environment }}
Expand Down
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@

Extra array functions built on top of the array API standard.

Used by:

- https://github.com/scipy/scipy
- ...

## Contributors

<!-- ALL-CONTRIBUTORS-LIST:START - Do not remove or modify this section -->
Expand Down
2 changes: 2 additions & 0 deletions codecov.yml
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
comment: false
ignore:
- "src/array_api_extra/_typing"
1 change: 1 addition & 0 deletions docs/contributing.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ pixi run ipython
pixi run pre-commit
pixi run pylint
pixi run mypy
pixi run pyright
```

Alternative environments are available with a subset of the dependencies and
Expand Down
528 changes: 246 additions & 282 deletions pixi.lock

Large diffs are not rendered by default.

24 changes: 22 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,19 +70,23 @@ array-api-extra = { path = ".", editable = true }

[tool.pixi.feature.lint.dependencies]
pre-commit = "*"
mypy = "*"
pylint = "*"
# import dependencies for mypy:
array-api-strict = "*"
numpy = "*"
pytest = "*"

[tool.pixi.feature.lint.pypi-dependencies]
basedmypy = { version = "*", extras = ["faster-cache"] }
basedpyright = "*"

[tool.pixi.feature.lint.tasks]
pre-commit-install = { cmd = "pre-commit install" }
pre-commit = { cmd = "pre-commit run -v --all-files --show-diff-on-failure" }
mypy = { cmd = "mypy", cwd = "." }
pylint = { cmd = ["pylint", "array_api_extra"], cwd = "src" }
lint = { depends-on = ["pre-commit", "pylint", "mypy"] }
pyright = { cmd = "basedpyright", cwd = "." }
lint = { depends-on = ["pre-commit", "pylint", "mypy", "pyright"] }

[tool.pixi.feature.tests.dependencies]
pytest = ">=6"
Expand Down Expand Up @@ -165,13 +169,29 @@ enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
warn_unreachable = true
disallow_untyped_defs = false
disallow_incomplete_defs = false
# data-apis/array-api#589
disallow_any_expr = false

[[tool.mypy.overrides]]
module = "array_api_extra.*"
disallow_untyped_defs = true
disallow_incomplete_defs = true


# pyright

[tool.basedpyright]
include = ["src", "tests"]
pythonVersion = "3.10"
pythonPlatform = "All"
typeCheckingMode = "all"

# data-apis/array-api#589
reportAny = false
reportExplicitAny = false
reportUnknownMemberType = false


# Ruff

[tool.ruff]
Expand Down
4 changes: 2 additions & 2 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

import typing
import warnings
from typing import TYPE_CHECKING

if TYPE_CHECKING:
if typing.TYPE_CHECKING:
from ._typing import Array, ModuleType

__all__ = ["atleast_nd", "cov", "create_diagonal", "expand_dims", "kron", "sinc"]
Expand Down
3 changes: 2 additions & 1 deletion src/array_api_extra/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from types import ModuleType
from typing import Any

Array = Any # To be changed to a Protocol later (see array-api#589)
# To be changed to a Protocol later (see data-apis/array-api#589)
Array = Any # type: ignore[no-any-explicit]

__all__ = ["Array", "ModuleType"]
20 changes: 10 additions & 10 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
from __future__ import annotations

import contextlib
import typing
import warnings
from typing import TYPE_CHECKING, Any

# array-api-strict#6
import array_api_strict as xp # type: ignore[import-untyped]
import array_api_strict as xp # type: ignore[import-untyped] # pyright: ignore[reportMissingTypeStubs]
import numpy as np
import pytest
from numpy.testing import assert_allclose, assert_array_equal, assert_equal

from array_api_extra import atleast_nd, cov, create_diagonal, expand_dims, kron, sinc

if TYPE_CHECKING:
Array = Any # To be changed to a Protocol later (see array-api#589)
if typing.TYPE_CHECKING:
from array_api_extra._typing import Array


class TestAtLeastND:
Expand Down Expand Up @@ -131,7 +131,7 @@ def test_1d(self):

@pytest.mark.parametrize("n", range(1, 10))
@pytest.mark.parametrize("offset", range(1, 10))
def test_create_diagonal(self, n, offset):
def test_create_diagonal(self, n: int, offset: int):
# from scipy._lib tests
rng = np.random.default_rng(2347823)
one = xp.asarray(1.0)
Expand Down Expand Up @@ -180,9 +180,9 @@ def test_basic(self):
assert_array_equal(kron(a, b, xp=xp), k)

def test_kron_smoke(self):
a = xp.ones([3, 3])
b = xp.ones([3, 3])
k = xp.ones([9, 9])
a = xp.ones((3, 3))
b = xp.ones((3, 3))
k = xp.ones((9, 9))

assert_array_equal(kron(a, b, xp=xp), k)

Expand All @@ -197,7 +197,7 @@ def test_kron_smoke(self):
((2, 0, 0, 2), (2, 0, 2)),
],
)
def test_kron_shape(self, shape_a, shape_b):
def test_kron_shape(self, shape_a: tuple[int, ...], shape_b: tuple[int, ...]):
a = xp.ones(shape_a)
b = xp.ones(shape_b)
normalised_shape_a = xp.asarray(
Expand Down Expand Up @@ -271,7 +271,7 @@ def test_simple(self):
assert_allclose(w, xp.flip(w, axis=0))

@pytest.mark.parametrize("x", [0, 1 + 3j])
def test_dtype(self, x):
def test_dtype(self, x: int | complex):
with pytest.raises(ValueError, match="real floating data type"):
sinc(xp.asarray(x), xp=xp)

Expand Down

0 comments on commit 1a33939

Please sign in to comment.