diff --git a/pyproject.toml b/pyproject.toml index d01da8c..9fefc89 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "hatchling.build" [project] name = "slim-trees" description = "A python package for efficient pickling of ML models." -version = "0.2.11" +version = "0.2.12" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.8" diff --git a/slim_trees/__init__.py b/slim_trees/__init__.py index b71f4f4..221aedb 100644 --- a/slim_trees/__init__.py +++ b/slim_trees/__init__.py @@ -13,7 +13,7 @@ import importlib.metadata import warnings from pathlib import Path -from typing import Any, Optional, Union +from typing import Any, BinaryIO, Optional, Union from slim_trees.pickling import ( dump_compressed, @@ -42,14 +42,16 @@ def dump_sklearn_compressed( - model: Any, path: Union[str, Path], compression: Optional[Union[str, dict]] = None + model: Any, + file: Union[str, Path, BinaryIO], + compression: Optional[Union[str, dict]] = None, ): """ Pickles a model and saves a compressed version to the disk. Saves the parameters of the model as int16 and float32 instead of int64 and float64. :param model: the model to save - :param path: where to save the model + :param file: where to save the model, either a path or a file object :param compression: the compression method used. Either a string or a dict with key 'method' set to the compression method and other key-value pairs are forwarded to `open` of the compression library. @@ -57,7 +59,7 @@ def dump_sklearn_compressed( """ from slim_trees.sklearn_tree import dump - dump_compressed(model, path, compression, dump) + dump_compressed(model, file, compression, dump) def dumps_sklearn_compressed( @@ -79,14 +81,16 @@ def dumps_sklearn_compressed( def dump_lgbm_compressed( - model: Any, path: Union[str, Path], compression: Optional[Union[str, dict]] = None + model: Any, + file: Union[str, Path, BinaryIO], + compression: Optional[Union[str, dict]] = None, ): """ Pickles a model and saves a compressed version to the disk. Saves the parameters of the model as int16 and float32 instead of int64 and float64. :param model: the model to save - :param path: where to save the model + :param file: where to save the model, either a path or a file object :param compression: the compression method used. Either a string or a dict with key 'method' set to the compression method and other key-value pairs are forwarded to `open` of the compression library. @@ -94,7 +98,7 @@ def dump_lgbm_compressed( """ from slim_trees.lgbm_booster import dump - dump_compressed(model, path, compression, dump) + dump_compressed(model, file, compression, dump) def dumps_lgbm_compressed( diff --git a/slim_trees/pickling.py b/slim_trees/pickling.py index f3af990..a324cad 100644 --- a/slim_trees/pickling.py +++ b/slim_trees/pickling.py @@ -5,7 +5,7 @@ import pathlib import pickle from collections.abc import Callable -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, BinaryIO, Dict, Optional, Tuple, Union class _NoCompression: @@ -51,7 +51,7 @@ def _get_default_kwargs(compression_method: str) -> Dict[str, Any]: def _unpack_compression_args( compression: Optional[Union[str, Dict[str, Any]]] = None, - path: Optional[Union[str, pathlib.Path]] = None, + file: Optional[Union[str, pathlib.Path, BinaryIO]] = None, ) -> Tuple[str, dict]: if compression is not None: if isinstance(compression, str): @@ -61,16 +61,16 @@ def _unpack_compression_args( k: compression[k] for k in compression if k != "method" } raise ValueError("compression must be either a string or a dict") - if path is not None: + if file is not None and isinstance(file, (str, pathlib.Path)): # try to find out the compression using the file extension - compression_method = _get_compression_from_path(path) + compression_method = _get_compression_from_path(file) return compression_method, _get_default_kwargs(compression_method) - raise ValueError("path or compression must not be None.") + raise ValueError("File must be a path or compression must not be None.") def dump_compressed( obj: Any, - path: Union[str, pathlib.Path], + file: Union[str, pathlib.Path, BinaryIO], compression: Optional[Union[str, dict]] = None, dump_function: Optional[Callable] = None, ): @@ -78,7 +78,7 @@ def dump_compressed( Pickles a model and saves it to the disk. If compression is not specified, the compression method will be determined by the file extension. :param obj: the object to pickle - :param path: where to save the object + :param file: where to save the object, either a path or a file object :param compression: the compression method used. Either a string or a dict with key 'method' set to the compression method and other key-value pairs are forwarded to open() of the compression library. @@ -89,11 +89,11 @@ def dump_compressed( if dump_function is None: dump_function = pickle.dump - compression_method, kwargs = _unpack_compression_args(compression, path) + compression_method, kwargs = _unpack_compression_args(compression, file) with _get_compression_library(compression_method).open( - path, mode="wb", **kwargs - ) as file: - dump_function(obj, file) + file, mode="wb", **kwargs + ) as fd: + dump_function(obj, fd) def dumps_compressed( @@ -124,13 +124,13 @@ def dumps_compressed( def load_compressed( - path: Union[str, pathlib.Path], + file: Union[str, pathlib.Path, BinaryIO], compression: Optional[Union[str, dict]] = None, unpickler_class: type = pickle.Unpickler, ) -> Any: """ Loads a compressed model. - :param path: where to load the object from + :param file: where to load the object from, either a path or a file object :param compression: the compression method used. Either a string or a dict with key 'method' set to the compression method and other key-value pairs which are forwarded to open() of the compression library. @@ -139,11 +139,11 @@ def load_compressed( This is useful to restrict possible imports or to allow unpickling when required module or function names have been refactored. """ - compression_method, kwargs = _unpack_compression_args(compression, path) + compression_method, kwargs = _unpack_compression_args(compression, file) with _get_compression_library(compression_method).open( - path, mode="rb", **kwargs - ) as file: - return unpickler_class(file).load() + file, mode="rb", **kwargs + ) as fd: + return unpickler_class(fd).load() def loads_compressed( diff --git a/slim_trees/sklearn_tree.py b/slim_trees/sklearn_tree.py index 71648a5..32ceccf 100644 --- a/slim_trees/sklearn_tree.py +++ b/slim_trees/sklearn_tree.py @@ -46,7 +46,7 @@ def dumps(model: Any) -> bytes: def _tree_pickle(tree: Tree): assert isinstance(tree, Tree) reconstructor, args, state = tree.__reduce__() - compressed_state = _compress_tree_state(state) + compressed_state = _compress_tree_state(state) # type: ignore return _tree_unpickle, (reconstructor, args, (slim_trees_version, compressed_state)) @@ -113,7 +113,7 @@ def _compress_tree_state(state: Dict) -> Dict: "values": values, }, **( - {"missing_go_to_left": np.packbits(missing_go_to_left)} + {"missing_go_to_left": np.packbits(missing_go_to_left)} # type: ignore if sklearn_version_ge_130 else {} ), diff --git a/tests/test_lgbm_compression.py b/tests/test_lgbm_compression.py index 64037e0..31b00dd 100644 --- a/tests/test_lgbm_compression.py +++ b/tests/test_lgbm_compression.py @@ -107,4 +107,19 @@ def test_loads_compressed_custom_unpickler(lgbm_regressor): loads_compressed(compressed, unpickler_class=_TestUnpickler) +def test_dump_and_load_from_file(tmp_path, lgbm_regressor): + with (tmp_path / "model.pickle.lzma").open("wb") as file: + dump_lgbm_compressed(lgbm_regressor, file, compression="lzma") + + with (tmp_path / "model.pickle.lzma").open("rb") as file: + load_compressed(file, compression="lzma") + + # No compression method specified + with pytest.raises(ValueError), (tmp_path / "model.pickle.lzma").open("rb") as file: + load_compressed(file) + + with pytest.raises(ValueError), (tmp_path / "model.pickle.lzma").open("wb") as file: + dump_lgbm_compressed(lgbm_regressor, file) + + # todo add tests for large models diff --git a/tests/test_sklearn_compression.py b/tests/test_sklearn_compression.py index 0cb649c..28192ef 100644 --- a/tests/test_sklearn_compression.py +++ b/tests/test_sklearn_compression.py @@ -159,4 +159,19 @@ def test_loads_compressed_custom_unpickler(random_forest_regressor): loads_compressed(compressed, unpickler_class=_TestUnpickler) +def test_dump_and_load_from_file(tmp_path, random_forest_regressor): + with (tmp_path / "model.pickle.lzma").open("wb") as file: + dump_sklearn_compressed(random_forest_regressor, file, compression="lzma") + + with (tmp_path / "model.pickle.lzma").open("rb") as file: + load_compressed(file, compression="lzma") + + # No compression method specified + with pytest.raises(ValueError), (tmp_path / "model.pickle.lzma").open("rb") as file: + load_compressed(file) + + with pytest.raises(ValueError), (tmp_path / "model.pickle.lzma").open("wb") as file: + dump_sklearn_compressed(random_forest_regressor, file) + + # todo add tests for large models