
How to save the fitted RandomForestRegressor model? #313

Open
heaynking opened this issue Sep 27, 2023 · 2 comments

Comments

@heaynking

How can I save a lolopy model?

I tried to train a model like this:

from lolopy.learners import RandomForestRegressor
model = RandomForestRegressor()
model.fit(X, Y)

After that, I attempted to save the model using:

import joblib
joblib.dump(model, "./model.pkl")

But it failed with the following error:

AttributeError: 'RandomForestRegressor' object has no attribute 'gateway'

Thank you for sharing your great program.

@kyledmiller

I have the same issue. Here are some additional details that might be helpful for debugging.

Version

3.0.0 (current pip version)

Issue

Lolopy fails to pickle a trained model when saving it with joblib.

Minimal code to reproduce

from lolopy.learners import RandomForestRegressor as LoloRandomForestRegressor
import joblib
import numpy as np

X = np.random.rand(20,5)
y = np.random.rand(20,1)

model = LoloRandomForestRegressor()
model.fit(X, y)
joblib.dump(model, 'model.joblib')

Error Message

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[5], line 10
      8 model = LoloRandomForestRegressor()
      9 model.fit(X, y)
---> 10 joblib.dump(model, 'model.joblib')

File ~/Applications/miniforge3/envs/pmg/lib/python3.12/site-packages/joblib/numpy_pickle.py:553, in dump(value, filename, compress, protocol, cache_size)
    551 elif is_filename:
    552     with open(filename, 'wb') as f:
--> 553         NumpyPickler(f, protocol=protocol).dump(value)
    554 else:
    555     NumpyPickler(filename, protocol=protocol).dump(value)

File ~/Applications/miniforge3/envs/pmg/lib/python3.12/pickle.py:481, in _Pickler.dump(self, obj)
    479 if self.proto >= 4:
    480     self.framer.start_framing()
--> 481 self.save(obj)
    482 self.write(STOP)
    483 self.framer.end_framing()

File ~/Applications/miniforge3/envs/pmg/lib/python3.12/site-packages/joblib/numpy_pickle.py:355, in NumpyPickler.save(self, obj)
    352     wrapper.write_array(obj, self)
    353     return
--> 355 return Pickler.save(self, obj)

File ~/Applications/miniforge3/envs/pmg/lib/python3.12/pickle.py:572, in _Pickler.save(self, obj, save_persistent_id)
    570 reduce = getattr(obj, "__reduce_ex__", None)
    571 if reduce is not None:
--> 572     rv = reduce(self.proto)
    573 else:
    574     reduce = getattr(obj, "__reduce__", None)

File ~/Applications/miniforge3/envs/pmg/lib/python3.12/site-packages/lolopy/learners.py:59, in BaseLoloLearner.__getstate__(self)
     57 # If there is a model set, replace it with the JVM copy
     58 if self.model_ is not None:
---> 59     state['model_'] = self.gateway.jvm.io.citrine.lolo.util.LoloPyDataLoader.serializeObject(self.model_,
     60                                                                                              self._compress_level)
     61 return state

AttributeError: 'RandomForestRegressor' object has no attribute 'gateway'
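The traceback shows that joblib ends up in pickle's __reduce_ex__, which calls BaseLoloLearner.__getstate__, and that method reads self.gateway, an attribute the fitted object does not have. The same failure can be reproduced without lolopy or a JVM. This is a minimal sketch with a hypothetical BrokenLearner class standing in for the real learner; it is not lolopy's code:

```python
import pickle


class BrokenLearner:
    """Hypothetical stand-in: __getstate__ uses an attribute that was never set."""

    def __init__(self):
        self.model_ = "fitted"  # placeholder for a trained model

    def __getstate__(self):
        # pickle calls this to collect the object's state
        state = self.__dict__.copy()
        if self.model_ is not None:
            # self.gateway was never assigned, so this line raises AttributeError
            state["model_"] = self.gateway.serialize(self.model_)
        return state


error = None
try:
    pickle.dumps(BrokenLearner())
except AttributeError as exc:
    error = exc

print(error)  # the message names the missing 'gateway' attribute
```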

@kyledmiller

Solved by #318. For now, use the new save and load methods instead of joblib.dump.

Updated example code (working):

from lolopy.learners import RandomForestRegressor as LoloRandomForestRegressor
import numpy as np

X = np.random.rand(20,5)
y = np.random.rand(20,1)

model = LoloRandomForestRegressor()
model.fit(X, y)
print(model.predict(X))

### Save
model.save('model.lolopy.rfr')

### Load
model = LoloRandomForestRegressor.load('model.lolopy.rfr')
print(model.predict(X))
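For background, a common Python pattern for making an object that holds a live, unpicklable handle (such as a JVM gateway) picklable is to drop the handle in __getstate__ and recreate it in __setstate__. The sketch below shows only that general pattern. ResourceHolder is a hypothetical class, and a threading.Lock stands in for the unpicklable resource. This is not how lolopy or #318 implements save and load:

```python
import pickle
import threading


class ResourceHolder:
    """Hypothetical class holding an unpicklable handle plus picklable state."""

    def __init__(self, params):
        self.params = params
        # Stand-in for a live connection; lock objects cannot be pickled
        self._handle = threading.Lock()

    def __getstate__(self):
        # Copy the instance dict and drop the handle so only plain data is pickled
        state = self.__dict__.copy()
        del state["_handle"]
        return state

    def __setstate__(self, state):
        # Restore the plain data, then recreate the handle after loading
        self.__dict__.update(state)
        self._handle = threading.Lock()


original = ResourceHolder({"seed": 42})
restored = pickle.loads(pickle.dumps(original))
print(restored.params)  # {'seed': 42}
```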
