Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance cross-platform compatibility for loading PySRRegressor models #681

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions pysr/sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
from datetime import datetime
from io import StringIO
from multiprocessing import cpu_count
from pathlib import Path
from pathlib import Path, PurePosixPath, PureWindowsPath
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, cast

import numpy as np
import pandas as pd
from numpy import ndarray
from numpy.typing import NDArray
from numpy.typing import ArrayLike, NDArray
zzccchen marked this conversation as resolved.
Show resolved Hide resolved
from sklearn.base import BaseEstimator, MultiOutputMixin, RegressorMixin
from sklearn.utils import check_array, check_consistent_length, check_random_state
from sklearn.utils.validation import _check_feature_names_in # type: ignore
Expand Down Expand Up @@ -949,7 +949,7 @@ def __init__(
@classmethod
def from_file(
cls,
equation_file: PathLike,
equation_file: Union[str, Path],
zzccchen marked this conversation as resolved.
Show resolved Hide resolved
*,
binary_operators: Optional[List[str]] = None,
unary_operators: Optional[List[str]] = None,
Expand Down Expand Up @@ -997,6 +997,20 @@ def from_file(
The model with fitted equations.
"""

class CustomUnpickler(pkl.Unpickler):
zzccchen marked this conversation as resolved.
Show resolved Hide resolved
def find_class(self, module, name):
if module == "pathlib":
if name == "PosixPath":
return PurePosixPath
elif name == "WindowsPath":
return PureWindowsPath
return super().find_class(module, name)
zzccchen marked this conversation as resolved.
Show resolved Hide resolved

def path_to_str(path):
if isinstance(path, (PurePosixPath, PureWindowsPath)):
return str(path)
return path
zzccchen marked this conversation as resolved.
Show resolved Hide resolved

pkl_filename = _csv_filename_to_pkl_filename(equation_file)

# Try to load model from <equation_file>.pkl
Expand All @@ -1007,11 +1021,10 @@ def from_file(
assert unary_operators is None
assert n_features_in is None
with open(pkl_filename, "rb") as f:
model = pkl.load(f)
# Change equation_file_ to be in the same dir as the pickle file
base_dir = os.path.dirname(pkl_filename)
base_equation_file = os.path.basename(model.equation_file_)
model.equation_file_ = os.path.join(base_dir, base_equation_file)
unpickler = CustomUnpickler(f)
model = unpickler.load()
# Convert equation_file_ to string to ensure cross-platform compatibility
model.equation_file_ = path_to_str(model.equation_file_)

# Update any parameters if necessary, such as
# extra_sympy_mappings:
Expand Down
Loading