#!/usr/bin/env python
# CREATED:2013-07-10 10:14:41 by Brian McFee <[email protected]>
# Wrapper for sklearn estimators to buffer generator output for use with
# stochastic optimization via partial_fit()
#
# from https://github.com/bmcfee/ml_scraps

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin


class BufferedEstimator(BaseEstimator):

    def __init__(self, estimator, batch_size=256, sparse=False):
        """
        :parameters:
            - estimator : sklearn.base.BaseEstimator
                Any classifier/transformer that supports a partial_fit method

            - batch_size : int > 0
                The amount of data to buffer for each call to partial_fit

            - sparse : boolean
                Is the data generator sparse?
        """

        # Make sure that the estimator is of the right base type
        if not isinstance(estimator, BaseEstimator):
            raise TypeError('estimator must extend from sklearn.base.BaseEstimator')

        self.estimator = estimator
        self.batch_size = batch_size
        self.sparse = sparse

        # Are we classifying (supervised) or transforming (unsupervised)?
        self.supervised = isinstance(estimator, ClassifierMixin)

        # This will only work if the estimator supports partial_fit
        assert hasattr(estimator, 'partial_fit')

    def fit(self, generator):

        def _run(X):
            # For a supervised estimator, each buffered item is expected to be
            # a (features, label) pair; split the labels off before fitting
            if self.supervised:
                y = np.array([z[-1] for z in X])
                X = np.array([z[0] for z in X])
                self.estimator.partial_fit(X, y)
            else:
                X = np.array(X)
                self.estimator.partial_fit(X)

        X = []
        for x_new in generator:
            X.append(x_new)

            # Once the buffer is full, fit a batch and reset the buffer
            if len(X) == self.batch_size:
                _run(X)
                X = []

        # Fit the last (partial) batch, if there is one
        if len(X) > 0:
            _run(X)

        # Return self, following the sklearn fit() convention
        return self
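

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original module): stream randomly
    # generated feature vectors into an unsupervised estimator. IncrementalPCA
    # is chosen here only because it supports partial_fit with no extra
    # arguments; any estimator with a compatible partial_fit should work the
    # same way.
    from sklearn.decomposition import IncrementalPCA

    def random_points(n_samples=2000, n_features=20):
        # Simulate a streaming data source: yield one sample at a time
        for _ in range(n_samples):
            yield np.random.randn(n_features)

    buffered = BufferedEstimator(IncrementalPCA(n_components=5), batch_size=256)
    buffered.fit(random_points())

    # The wrapped estimator is fitted in place; inspect it directly
    print(buffered.estimator.components_.shape)  # -> (5, 20)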