Skip to content

Commit

Permalink
[ext/sklearn] updated parameter_search (fixes a undetected test failu…
Browse files Browse the repository at this point in the history
…re in doctest) (markovmodel#829)
  • Loading branch information
marscher authored Jun 20, 2016
1 parent 7dc5f7f commit 9b34382
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 59 deletions.
19 changes: 0 additions & 19 deletions pyemma/_ext/sklearn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +0,0 @@

# This file is part of PyEMMA.
#
# Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
#
# PyEMMA is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

__author__ = 'noe'
18 changes: 0 additions & 18 deletions pyemma/_ext/sklearn/base.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,3 @@

# This file is part of PyEMMA.
#
# Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
#
# PyEMMA is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""
--------------------------------------------------------------------------------------------
Extracted from skikit-learn to ensure basic compatibility
Expand Down
82 changes: 60 additions & 22 deletions pyemma/_ext/sklearn/parameter_search.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,3 @@

# This file is part of PyEMMA.
#
# Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
#
# PyEMMA is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""
--------------------------------------------------------------------------------------------
Extracted from skikit-learn to ensure basic compatibility
Expand Down Expand Up @@ -43,20 +25,30 @@
from functools import reduce
import operator
from six.moves import zip
import numpy as np



class ParameterGrid(object):
"""Grid of parameters with a discrete number of values for each.
Can be used to iterate over parameter value combinations with the
Python built-in function iter.
Read more in the :ref:`User Guide <grid_search>`.
Parameters
----------
param_grid : dict of string to sequence, or sequence of such
The parameter grid to explore, as a dictionary mapping estimator
parameters to sequences of allowed values.
An empty dict signifies default parameters.
A sequence of dicts signifies a sequence of grids to search, and is
useful to avoid exploring parameter combinations that make no sense
or have no effect. See the examples below.
Examples
--------
>>> from sklearn.grid_search import ParameterGrid
Expand All @@ -65,11 +57,15 @@ class ParameterGrid(object):
... [{'a': 1, 'b': True}, {'a': 1, 'b': False},
... {'a': 2, 'b': True}, {'a': 2, 'b': False}])
True
>>> grid = [{'kernel': ['linear']}, {'kernel': ['rbf'], 'gamma': [1, 10]}]
>>> list(ParameterGrid(grid)) == [{'kernel': 'linear'},
... {'kernel': 'rbf', 'gamma': 1},
... {'kernel': 'rbf', 'gamma': 10}]
True
>>> ParameterGrid(grid)[1] == {'kernel': 'rbf', 'gamma': 1}
True
See also
--------
:class:`GridSearchCV`:
Expand All @@ -85,6 +81,7 @@ def __init__(self, param_grid):

def __iter__(self):
"""Iterate over the points in the grid.
Returns
-------
params : iterator over dict of string to any
Expand All @@ -97,14 +94,55 @@ def __iter__(self):
if not items:
yield {}
else:
keys, values = list(zip(*items))
keys, values = zip(*items)
for v in product(*values):
params = dict(list(zip(keys, v)))
params = dict(zip(keys, v))
yield params

def __len__(self):
"""Number of points on the grid."""
# Product function that can handle iterables (np.product can't).
product = partial(reduce, operator.mul)
return sum(product(len(v) for v in list(p.values())) if p else 1
for p in self.param_grid)
return sum(product(len(v) for v in p.values()) if p else 1
for p in self.param_grid)

def __getitem__(self, ind):
"""Get the parameters that would be ``ind``th in iteration
Parameters
----------
ind : int
The iteration index
Returns
-------
params : dict of string to any
Equal to list(self)[ind]
"""
# This is used to make discrete sampling without replacement memory
# efficient.
for sub_grid in self.param_grid:
# XXX: could memoize information used here
if not sub_grid:
if ind == 0:
return {}
else:
ind -= 1
continue

# Reverse so most frequent cycling parameter comes first
keys, values_lists = zip(*sorted(sub_grid.items())[::-1])
sizes = [len(v_list) for v_list in values_lists]
total = np.product(sizes)

if ind >= total:
# Try the next grid
ind -= total
else:
out = {}
for key, v_list, n in zip(keys, values_lists, sizes):
ind, offset = divmod(ind, n)
out[key] = v_list[offset]
return out

raise IndexError('ParameterGrid index out of range')

0 comments on commit 9b34382

Please sign in to comment.