Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add typing overloads for passing file BinaryIO #136

Merged
merged 2 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions slim_trees/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import importlib.metadata
import warnings
from pathlib import Path
from typing import Any, BinaryIO, Optional, Union
from typing import Any, BinaryIO, Optional, Union, overload

from slim_trees.pickling import (
dump_compressed,
Expand Down Expand Up @@ -41,6 +41,22 @@
]


@overload
def dump_sklearn_compressed(
model: Any,
file: BinaryIO,
compression: Union[str, dict],
): ...


@overload
def dump_sklearn_compressed(
model: Any,
file: Union[str, Path],
compression: Optional[Union[str, dict]] = None,
): ...


def dump_sklearn_compressed(
model: Any,
file: Union[str, Path, BinaryIO],
Expand All @@ -59,7 +75,7 @@ def dump_sklearn_compressed(
"""
from slim_trees.sklearn_tree import dump

dump_compressed(model, file, compression, dump)
dump_compressed(model, file, compression, dump) # type: ignore


def dumps_sklearn_compressed(
Expand All @@ -80,6 +96,22 @@ def dumps_sklearn_compressed(
return dumps_compressed(model, compression, dumps)


@overload
def dump_lgbm_compressed(
model: Any,
file: BinaryIO,
compression: Union[str, dict],
): ...


@overload
def dump_lgbm_compressed(
model: Any,
file: Union[str, Path],
compression: Optional[Union[str, dict]] = None,
): ...


def dump_lgbm_compressed(
model: Any,
file: Union[str, Path, BinaryIO],
Expand All @@ -98,7 +130,7 @@ def dump_lgbm_compressed(
"""
from slim_trees.lgbm_booster import dump

dump_compressed(model, file, compression, dump)
dump_compressed(model, file, compression, dump) # type: ignore


def dumps_lgbm_compressed(
Expand Down
36 changes: 35 additions & 1 deletion slim_trees/pickling.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pathlib
import pickle
from collections.abc import Callable
from typing import Any, BinaryIO, Dict, Optional, Tuple, Union
from typing import Any, BinaryIO, Dict, Optional, Tuple, Union, overload


class _NoCompression:
Expand Down Expand Up @@ -68,6 +68,24 @@ def _unpack_compression_args(
raise ValueError("File must be a path or compression must not be None.")


@overload
def dump_compressed(
obj: Any,
file: BinaryIO,
compression: Union[str, dict],
dump_function: Optional[Callable] = None,
): ...


@overload
def dump_compressed(
obj: Any,
file: Union[str, pathlib.Path],
compression: Optional[Union[str, dict]] = None,
dump_function: Optional[Callable] = None,
): ...


def dump_compressed(
obj: Any,
file: Union[str, pathlib.Path, BinaryIO],
Expand Down Expand Up @@ -123,6 +141,22 @@ def dumps_compressed(
return _get_compression_library(compression_method).compress(data_uncompressed)


@overload
def load_compressed(
file: BinaryIO,
compression: Union[str, dict],
unpickler_class: type = pickle.Unpickler,
): ...


@overload
def load_compressed(
file: Union[str, pathlib.Path],
compression: Optional[Union[str, dict]] = None,
unpickler_class: type = pickle.Unpickler,
): ...


def load_compressed(
file: Union[str, pathlib.Path, BinaryIO],
compression: Optional[Union[str, dict]] = None,
Expand Down