Skip to content

Commit

Permalink
Fix variance issue with Par generic.
Browse files Browse the repository at this point in the history
  • Loading branch information
MylesBartlett committed Aug 1, 2024
1 parent f3d6eed commit 01fd8c9
Show file tree
Hide file tree
Showing 10 changed files with 151 additions and 83 deletions.
4 changes: 2 additions & 2 deletions pyrightconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
"serox/**"
],
"exclude": [
"**/.",
"**/__pycache__",
"**/node_modules",
"**/.undodir",
".venv",
"**/.cache"
],
Expand All @@ -18,7 +18,7 @@
"reportUnusedCallResult": "error",
"reportUnnecessaryTypeIgnoreComment": "warning",
"reportMissingSuperCall": "warning",
"reportImportCycles": "error",
"reportImportCycles": "none",
"reportShadowedImports": "warning",
"reportUninitializedInstanceVariable": "error",
"reportPropertyTypeMismatch": "error",
Expand Down
11 changes: 6 additions & 5 deletions requirements-dev.lock
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# all-features: false
# with-sources: false
# generate-hashes: false
# universal: false

-e file:.
black==24.4.2
Expand Down Expand Up @@ -44,21 +45,21 @@ packaging==24.1
# via pytest
pathspec==0.12.1
# via black
pip==24.1.2
pip==24.2
platformdirs==4.2.2
# via black
# via virtualenv
pluggy==1.5.0
# via pytest
pre-commit==3.7.1
pre-commit==3.8.0
pydoclint==0.5.6
pyright==1.1.372
pytest==8.3.1
pyright==1.1.374
pytest==8.3.2
# via pytest-cov
pytest-cov==5.0.0
pyyaml==6.0.1
# via pre-commit
ruff==0.5.4
ruff==0.5.5
typing-extensions==4.12.2
# via serox
virtualenv==20.26.3
Expand Down
1 change: 1 addition & 0 deletions requirements.lock
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# all-features: false
# with-sources: false
# generate-hashes: false
# universal: false

-e file:.
joblib==1.4.2
Expand Down
7 changes: 4 additions & 3 deletions serox/collections/hash_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses import dataclass
from typing import Any, Generator, Hashable, Iterable, Literal, Sized, override

from serox.common import False_, True_
from serox.conftest import TESTING
from serox.convert import Into
from serox.default import Default
Expand Down Expand Up @@ -123,7 +124,7 @@ def clone(self) -> HashMap[K, V]:


@dataclass(repr=True, init=False)
class Keys[K, P: bool](Iterator[K, P]):
class Keys[K, P: (True_, False_)](Iterator[K, P]):
def __init__(self, inner: HashMap[K, Any], /, par: P) -> None:
super().__init__()
self.iter = iter(inner.inner.keys())
Expand All @@ -138,7 +139,7 @@ def next(self) -> Option[K]:


@dataclass(repr=True, init=False)
class Values[V, P: bool](Iterator[V, P]):
class Values[V, P: (True_, False_)](Iterator[V, P]):
def __init__(self, inner: HashMap[Any, V], /, par: P) -> None:
super().__init__()
self.iter = iter(inner.inner.values())
Expand All @@ -153,7 +154,7 @@ def next(self) -> Option[V]:


@dataclass(repr=True, init=False)
class Entries[K, V, P: bool](Iterator[Entry[K, V], P]):
class Entries[K, V, P: (True_, False_)](Iterator[Entry[K, V], P]):
def __init__(self, inner: HashMap[K, V], /, par: P) -> None:
super().__init__()
self.iter = iter(inner.inner.items())
Expand Down
29 changes: 18 additions & 11 deletions serox/collections/hash_set.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Generator, Hashable, Iterable, Literal, Self, Sized, override
from typing import Any, Generator, Hashable, Iterable, Self, Sized, override
from typing import Iterator as NativeIterator

from serox.common import False_, True_
from serox.convert import Into
from serox.default import Default
from serox.iter import Extend, FromIterator, IntoIterator, IntoParIterator, Iterator
Expand Down Expand Up @@ -125,27 +127,32 @@ def from_iter(cls, iter: Iterable[T], /) -> HashSet[T]:
return HashSet(*iter)

@override
def iter(self) -> Iter[T, Literal[False]]:
return Iter(self, par=False)
def iter(self) -> Iter[T, False_]:
return Iter.new(self, par=False)

def __iter__(self) -> Generator[T, None, None]:
yield from self.iter()

@override
def par_iter(self) -> Iter[T, Literal[True]]:
return Iter(self, par=True)
def par_iter(self) -> Iter[T, True_]:
return Iter.new(self, par=True)

@override
def clone(self) -> HashSet[T]:
return HashSet(*self.inner.copy())


@dataclass(repr=True, init=False)
class Iter[Item, P: bool](Iterator[Item, P]):
def __init__(self, inner: HashSet[Item], /, par: P) -> None:
super().__init__()
self.iter = iter(inner.inner)
self.par = par
@dataclass(repr=True, frozen=True, kw_only=True)
class Iter[Item, Par: (True_, False_)](Iterator[Item, Par]):
iter: NativeIterator[Item]
par: Par

@classmethod
def new[Item2, Par2: (True_, False_)](
cls, data: HashSet[Item2], par: Par2 = True
) -> Iter[Item2, Par2]:
iter_ = iter(data.inner)
return Iter(iter=iter_, par=par)

@override
def next(self) -> Option[Item]:
Expand Down
10 changes: 10 additions & 0 deletions serox/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from __future__ import annotations
from typing import Literal

__all__ = [
"False_",
"True_",
]

type True_ = Literal[True]
type False_ = Literal[False]
84 changes: 60 additions & 24 deletions serox/iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
Callable,
Generator,
Iterable,
Literal,
Protocol,
Self,
cast,
Expand All @@ -25,6 +24,8 @@
)

from serox.cmp import Ord
from serox.common import False_, True_
from serox.conftest import TESTING
from serox.misc import SelfAddable, SelfMultiplicable

if TYPE_CHECKING:
Expand All @@ -50,20 +51,21 @@
"Zip",
]


type Fn1[T, U] = Callable[[T], U]


class FromIterator[A](Protocol):
@classmethod
def from_iter[P: bool](cls, iter: Iterator[A, P], /) -> Self: ...
def from_iter[P: (True_, False_)](cls, iter: Iterator[A, P], /) -> Self: ...


def _identity[T](x: T) -> T:
return x


@runtime_checkable
class Iterator[Item, Par: bool](Protocol):
class Iterator[Item, Par: (True_, False_)](Protocol):
par: Par

def next(self) -> Option[Item]: ...
Expand Down Expand Up @@ -195,9 +197,9 @@ def max[U: Ord](self: Iterator[U, Par]) -> U:
def min[U: Ord](self: Iterator[U, Par]) -> U:
return min(self)

def par_bridge(self) -> Iterator[Item, Literal[True]]:
def par_bridge(self) -> Iterator[Item, True_]:
object.__setattr__(self, "par", True)
return cast(Iterator[Item, Literal[True]], self)
return cast(Iterator[Item, True_], self)


class Chunk[Item](list[Item], FromIterator[Item]):
Expand All @@ -214,7 +216,7 @@ def is_empty(self) -> bool:


@dataclass
class ArrayChunk[Item, P: bool](Iterator[Chunk[Item], P]):
class ArrayChunk[Item, P: (True_, False_)](Iterator[Chunk[Item], P]):
iter: Iterator[Item, P]
n: int
par: P
Expand All @@ -231,22 +233,22 @@ def next(self) -> Option[Chunk[Item]]:


class IntoIterator[T](Protocol):
def iter(self) -> Iterator[T, Literal[False]]: ...
def iter(self) -> Iterator[T, False_]: ...


class IntoParIterator[T](Protocol):
def par_iter(self) -> Iterator[T, Literal[True]]: ...
def par_iter(self) -> Iterator[T, True_]: ...


class DoubleEndedIterator[Item, P: bool](Iterator[Item, P], Protocol):
class DoubleEndedIterator[Item, P: (True_, False_)](Iterator[Item, P], Protocol):
def next_back(self) -> Option[Item]: ...

def rev(self) -> Rev[Item, P]:
return Rev(self, par=self.par)


@dataclass(repr=True)
class Filter[Item, P: bool](Iterator[Item, P]):
class Filter[Item, P: (True_, False_)](Iterator[Item, P]):
iter: Iterator[Item, P]
f: Fn1[Item, bool]
par: P
Expand All @@ -265,7 +267,7 @@ def next(self) -> Option[Item]:


@dataclass(repr=True)
class FilterMap[Item, B, P: bool](Iterator[B, P]):
class FilterMap[Item, B, P: (True_, False_)](Iterator[B, P]):
iter: Iterator[Item, P]
f: Fn1[Item, Option[B]]
par: P
Expand All @@ -286,7 +288,7 @@ def next(self) -> Option[B]:


@dataclass(repr=True)
class Map[Item, B, P: bool](Iterator[B, P]):
class Map[Item, B, P: (True_, False_)](Iterator[B, P]):
iter: Iterator[Item, P]
f: Fn1[Item, B]
par: P
Expand All @@ -297,7 +299,7 @@ def next(self) -> Option[B]:


@dataclass(repr=True)
class Take[Item, P: bool](Iterator[Item, P]):
class Take[Item, P: (True_, False_)](Iterator[Item, P]):
iter: Iterator[Item, P]
_n: int
par: P
Expand Down Expand Up @@ -325,7 +327,7 @@ def nth(self, n: int) -> Option[Item]:


@dataclass(repr=True)
class TakeWhile[Item, P: bool](Iterator[Item, P]):
class TakeWhile[Item, P: (True_, False_)](Iterator[Item, P]):
iter: Iterator[Item, P]
predicate: Fn1[Item, bool]
par: P
Expand All @@ -349,7 +351,7 @@ def next(self) -> Option[Item]:


@dataclass(repr=True)
class Zip[A, B, P: bool](Iterator[tuple[A, B], P]):
class Zip[A, B, P: (True_, False_)](Iterator[tuple[A, B], P]):
a: Iterator[A, P]
b: Iterator[B, P]
par: P
Expand All @@ -371,7 +373,7 @@ def next(self) -> Option[tuple[A, B]]:

# Parametrising the first generic of `Iterator` as `Any` to avoid a circular import.
@dataclass(repr=True)
class ZipLongest[A, B, P: bool](Iterator[Any, P]):
class ZipLongest[A, B, P: (True_, False_)](Iterator[Any, P]):
a: Iterator[A, P]
b: Iterator[B, P]
par: P
Expand All @@ -393,7 +395,7 @@ def next(self) -> Option[tuple[A, B] | tuple[Null[A], B] | tuple[A, Null[B]]]:


@dataclass(repr=True)
class Chain[A, P: bool](Iterator[A, P]):
class Chain[A, P: (True_, False_)](Iterator[A, P]):
a: Iterator[A, P]
b: Iterator[A, P]
par: P
Expand All @@ -410,7 +412,7 @@ def next(self) -> Option[A]:


@dataclass(repr=True)
class Rev[Item, P: bool](Iterator[Item, P]):
class Rev[Item, P: (True_, False_)](Iterator[Item, P]):
iter: DoubleEndedIterator[Item, P]
par: P

Expand All @@ -426,12 +428,27 @@ def extend_one(self, item: Item) -> None:
self.extend(Some(item))


@dataclass(repr=True, init=False)
class Bridge[Item, P: bool](Iterator[Item, P]):
def __init__(self, iter: NativeIterator[Item], par: P = False) -> None:
super().__init__()
self.iter = iter
self.par = par
@dataclass(repr=True, frozen=True, kw_only=True)
class Bridge[Item, Par: (True_, False_)](Iterator[Item, Par]):
"""
A bridge between native Python iterators and `serox` ones.
Can be parallel (`par = True`) or non-parallel (`par = False`).
"""

iter: NativeIterator[Item]
"""The native Python iterator being bridged."""
par: Par
"""Whether to parallelise the iterator."""

@classmethod
def new[Item2, Par2: (True_, False_)](
cls, iter: NativeIterator[Item2], /, par: Par2 = True
) -> Bridge[Item2, Par2]:
return Bridge(iter=iter, par=par)

@classmethod
def par_new[Item2](cls, iter: NativeIterator[Item2], /) -> Bridge[Item2, True_]:
return Bridge(iter=iter, par=True)

@override
def next(self) -> Option[Item]:
Expand All @@ -441,3 +458,22 @@ def next(self) -> Option[Item]:
return Some(self.iter.__next__())
except StopIteration:
return Null()


if TESTING:

def test_par_invariance():
from .collections import HashMap
from .vec import Vec

values = Vec(*range(4))
keys = ["foo", "bar", "baz"]
bridge = Bridge.new(iter(keys), par=False)
bridge = Bridge(iter=iter(keys), par=False)
mapped = values.iter().map(lambda x: x**2)
_ = bridge.zip(mapped).collect(HashMap[str, int])

bridge = Bridge.new(iter(keys), par=True)
# shouldn't be able to combine parallel iterators with non-parallel ones
# for consistent typing
_ = bridge.zip(values.iter()) # pyright: ignore[reportArgumentType]
3 changes: 2 additions & 1 deletion serox/option.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
override,
)

from serox.common import False_, True_
from serox.convert import From, Into
from serox.default import Default
from serox.iter import DoubleEndedIterator, IntoIterator
Expand Down Expand Up @@ -301,7 +302,7 @@ def is_null[T](x: Option[T], /) -> TypeGuard[Null[T]]:
repr=True,
slots=True,
)
class Iter[Item, P: bool](DoubleEndedIterator[Item, P]):
class Iter[Item, P: (True_, False_)](DoubleEndedIterator[Item, P]):
item: Option[Item]
par: P

Expand Down
Loading

0 comments on commit 01fd8c9

Please sign in to comment.