Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for loading / writing from files #134

Merged
merged 3 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "hatchling.build"
[project]
name = "slim-trees"
description = "A python package for efficient pickling of ML models."
version = "0.2.11"
version = "0.2.12"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.8"
Expand Down
18 changes: 11 additions & 7 deletions slim_trees/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import importlib.metadata
import warnings
from pathlib import Path
from typing import Any, Optional, Union
from typing import Any, BinaryIO, Optional, Union

from slim_trees.pickling import (
dump_compressed,
Expand Down Expand Up @@ -42,22 +42,24 @@


def dump_sklearn_compressed(
model: Any, path: Union[str, Path], compression: Optional[Union[str, dict]] = None
model: Any,
file: Union[str, Path, BinaryIO],
compression: Optional[Union[str, dict]] = None,
):
"""
Pickles a model and saves a compressed version to the disk.

Saves the parameters of the model as int16 and float32 instead of int64 and float64.
:param model: the model to save
:param path: where to save the model
:param file: where to save the model, either a path or a file object
:param compression: the compression method used. Either a string or a dict with key 'method' set
to the compression method and other key-value pairs are forwarded to `open`
of the compression library.
Options: ["no", "lzma", "gzip", "bz2"]
"""
from slim_trees.sklearn_tree import dump

dump_compressed(model, path, compression, dump)
dump_compressed(model, file, compression, dump)


def dumps_sklearn_compressed(
Expand All @@ -79,22 +81,24 @@ def dumps_sklearn_compressed(


def dump_lgbm_compressed(
model: Any, path: Union[str, Path], compression: Optional[Union[str, dict]] = None
model: Any,
file: Union[str, Path, BinaryIO],
compression: Optional[Union[str, dict]] = None,
):
"""
Pickles a model and saves a compressed version to the disk.

Saves the parameters of the model as int16 and float32 instead of int64 and float64.
:param model: the model to save
:param path: where to save the model
:param file: where to save the model, either a path or a file object
:param compression: the compression method used. Either a string or a dict with key 'method' set
to the compression method and other key-value pairs are forwarded to `open`
of the compression library.
Options: ["no", "lzma", "gzip", "bz2"]
"""
from slim_trees.lgbm_booster import dump

dump_compressed(model, path, compression, dump)
dump_compressed(model, file, compression, dump)


def dumps_lgbm_compressed(
Expand Down
34 changes: 17 additions & 17 deletions slim_trees/pickling.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pathlib
import pickle
from collections.abc import Callable
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, BinaryIO, Dict, Optional, Tuple, Union


class _NoCompression:
Expand Down Expand Up @@ -51,7 +51,7 @@ def _get_default_kwargs(compression_method: str) -> Dict[str, Any]:

def _unpack_compression_args(
compression: Optional[Union[str, Dict[str, Any]]] = None,
path: Optional[Union[str, pathlib.Path]] = None,
file: Optional[Union[str, pathlib.Path, BinaryIO]] = None,
) -> Tuple[str, dict]:
if compression is not None:
if isinstance(compression, str):
Expand All @@ -61,24 +61,24 @@ def _unpack_compression_args(
k: compression[k] for k in compression if k != "method"
}
raise ValueError("compression must be either a string or a dict")
if path is not None:
if file is not None and isinstance(file, (str, pathlib.Path)):
# try to find out the compression using the file extension
compression_method = _get_compression_from_path(path)
compression_method = _get_compression_from_path(file)
return compression_method, _get_default_kwargs(compression_method)
raise ValueError("path or compression must not be None.")
raise ValueError("File must be a path or compression must not be None.")


def dump_compressed(
obj: Any,
path: Union[str, pathlib.Path],
file: Union[str, pathlib.Path, BinaryIO],
compression: Optional[Union[str, dict]] = None,
dump_function: Optional[Callable] = None,
):
"""
Pickles a model and saves it to the disk. If compression is not specified,
the compression method will be determined by the file extension.
:param obj: the object to pickle
:param path: where to save the object
:param file: where to save the object, either a path or a file object
:param compression: the compression method used. Either a string or a dict with key 'method' set
to the compression method and other key-value pairs are forwarded to open()
of the compression library.
Expand All @@ -89,11 +89,11 @@ def dump_compressed(
if dump_function is None:
dump_function = pickle.dump

compression_method, kwargs = _unpack_compression_args(compression, path)
compression_method, kwargs = _unpack_compression_args(compression, file)
with _get_compression_library(compression_method).open(
path, mode="wb", **kwargs
) as file:
dump_function(obj, file)
file, mode="wb", **kwargs
) as fd:
dump_function(obj, fd)


def dumps_compressed(
Expand Down Expand Up @@ -124,13 +124,13 @@ def dumps_compressed(


def load_compressed(
path: Union[str, pathlib.Path],
file: Union[str, pathlib.Path, BinaryIO],
compression: Optional[Union[str, dict]] = None,
unpickler_class: type = pickle.Unpickler,
) -> Any:
"""
Loads a compressed model.
:param path: where to load the object from
:param file: where to load the object from, either a path or a file object
:param compression: the compression method used. Either a string or a dict with key 'method'
set to the compression method and other key-value pairs which are forwarded
to open() of the compression library.
Expand All @@ -139,11 +139,11 @@ def load_compressed(
This is useful to restrict possible imports or to allow unpickling
when required module or function names have been refactored.
"""
compression_method, kwargs = _unpack_compression_args(compression, path)
compression_method, kwargs = _unpack_compression_args(compression, file)
with _get_compression_library(compression_method).open(
path, mode="rb", **kwargs
) as file:
return unpickler_class(file).load()
file, mode="rb", **kwargs
) as fd:
return unpickler_class(fd).load()


def loads_compressed(
Expand Down
4 changes: 2 additions & 2 deletions slim_trees/sklearn_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def dumps(model: Any) -> bytes:
def _tree_pickle(tree: Tree):
assert isinstance(tree, Tree)
reconstructor, args, state = tree.__reduce__()
compressed_state = _compress_tree_state(state)
compressed_state = _compress_tree_state(state) # type: ignore
return _tree_unpickle, (reconstructor, args, (slim_trees_version, compressed_state))


Expand Down Expand Up @@ -113,7 +113,7 @@ def _compress_tree_state(state: Dict) -> Dict:
"values": values,
},
**(
{"missing_go_to_left": np.packbits(missing_go_to_left)}
{"missing_go_to_left": np.packbits(missing_go_to_left)} # type: ignore
if sklearn_version_ge_130
else {}
),
Expand Down
15 changes: 15 additions & 0 deletions tests/test_lgbm_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,19 @@ def test_loads_compressed_custom_unpickler(lgbm_regressor):
loads_compressed(compressed, unpickler_class=_TestUnpickler)


def test_dump_and_load_from_file(tmp_path, lgbm_regressor):
    """Dump to and load from already-open binary file objects.

    Covers the file-object code path of ``dump_lgbm_compressed`` and
    ``load_compressed``: with an explicit ``compression`` the round trip must
    succeed, and without one a ``ValueError`` is expected because the
    compression method is only inferred from ``str`` / ``Path`` arguments.

    :param tmp_path: pytest's per-test temporary directory
    :param lgbm_regressor: fixture providing the model to pickle
    """
    model_path = tmp_path / "model.pickle.lzma"

    with model_path.open("wb") as file:
        dump_lgbm_compressed(lgbm_regressor, file, compression="lzma")

    with model_path.open("rb") as file:
        loaded = load_compressed(file, compression="lzma")
    # Do not discard the result: a load that silently returned something else
    # would otherwise let this test pass.
    assert isinstance(loaded, type(lgbm_regressor))

    # No compression method specified
    with pytest.raises(ValueError), model_path.open("rb") as file:
        load_compressed(file)

    # Opening with "wb" truncates the previously written model, so this check
    # must stay last.
    with pytest.raises(ValueError), model_path.open("wb") as file:
        dump_lgbm_compressed(lgbm_regressor, file)


# TODO: add tests for large models
15 changes: 15 additions & 0 deletions tests/test_sklearn_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,19 @@ def test_loads_compressed_custom_unpickler(random_forest_regressor):
loads_compressed(compressed, unpickler_class=_TestUnpickler)


def test_dump_and_load_from_file(tmp_path, random_forest_regressor):
    """Dump to and load from already-open binary file objects.

    Covers the file-object code path of ``dump_sklearn_compressed`` and
    ``load_compressed``: with an explicit ``compression`` the round trip must
    succeed, and without one a ``ValueError`` is expected because the
    compression method is only inferred from ``str`` / ``Path`` arguments.

    :param tmp_path: pytest's per-test temporary directory
    :param random_forest_regressor: fixture providing the model to pickle
    """
    model_path = tmp_path / "model.pickle.lzma"

    with model_path.open("wb") as file:
        dump_sklearn_compressed(random_forest_regressor, file, compression="lzma")

    with model_path.open("rb") as file:
        loaded = load_compressed(file, compression="lzma")
    # Do not discard the result: a load that silently returned something else
    # would otherwise let this test pass.
    assert isinstance(loaded, type(random_forest_regressor))

    # No compression method specified
    with pytest.raises(ValueError), model_path.open("rb") as file:
        load_compressed(file)

    # Opening with "wb" truncates the previously written model, so this check
    # must stay last.
    with pytest.raises(ValueError), model_path.open("wb") as file:
        dump_sklearn_compressed(random_forest_regressor, file)


# TODO: add tests for large models