Skip to content

Commit

Permalink
Clean deprecations for scikit-learn 1.0 | Kmeans (#651)
Browse files Browse the repository at this point in the history
  • Loading branch information
OnlyDeniko authored May 15, 2021
1 parent edafa1d commit a78649b
Showing 1 changed file with 44 additions and 31 deletions.
75 changes: 44 additions & 31 deletions daal4py/sklearn/cluster/_k_means_0_23.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,26 +240,27 @@ def _fit(self, X, y=None, sample_weight=None):
are assigned equal weight (default: None)
"""
if self.precompute_distances != 'deprecated':
if sklearn_check_version('0.24'):
warnings.warn("'precompute_distances' was deprecated in version "
"0.23 and will be removed in 1.0 (renaming of 0.25). It has no "
"effect", FutureWarning)
elif sklearn_check_version('0.23'):
warnings.warn("'precompute_distances' was deprecated in version "
"0.23 and will be removed in 0.25. It has no "
"effect", FutureWarning)

if self.n_jobs != 'deprecated':
if sklearn_check_version('0.24'):
warnings.warn("'n_jobs' was deprecated in version 0.23 and will be"
" removed in 1.0 (renaming of 0.25).", FutureWarning)
elif sklearn_check_version('0.23'):
warnings.warn("'n_jobs' was deprecated in version 0.23 and will be"
" removed in 0.25.", FutureWarning)
self._n_threads = self.n_jobs
else:
self._n_threads = None
if hasattr(self, 'precompute_distances'):
if self.precompute_distances != 'deprecated':
if sklearn_check_version('0.24'):
warnings.warn("'precompute_distances' was deprecated in version "
"0.23 and will be removed in 1.0 (renaming of 0.25)."
" It has no effect", FutureWarning)
elif sklearn_check_version('0.23'):
warnings.warn("'precompute_distances' was deprecated in version "
"0.23 and will be removed in 0.25. It has no "
"effect", FutureWarning)

self._n_threads = None
if hasattr(self, 'n_jobs'):
if self.n_jobs != 'deprecated':
if sklearn_check_version('0.24'):
warnings.warn("'n_jobs' was deprecated in version 0.23 and will be"
" removed in 1.0 (renaming of 0.25).", FutureWarning)
elif sklearn_check_version('0.23'):
warnings.warn("'n_jobs' was deprecated in version 0.23 and will be"
" removed in 0.25.", FutureWarning)
self._n_threads = self.n_jobs
self._n_threads = _openmp_effective_n_threads(self._n_threads)

if self.n_init <= 0:
Expand Down Expand Up @@ -366,17 +367,29 @@ def _predict(self, X, sample_weight=None):
class KMeans(KMeans_original):
__doc__ = KMeans_original.__doc__

@_deprecate_positional_args
def __init__(self, n_clusters=8, *, init='k-means++', n_init=10,
max_iter=300, tol=1e-4, precompute_distances='deprecated',
verbose=0, random_state=None, copy_x=True,
n_jobs='deprecated', algorithm='auto'):

super(KMeans, self).__init__(
n_clusters=n_clusters, init=init, max_iter=max_iter,
tol=tol, precompute_distances=precompute_distances,
n_init=n_init, verbose=verbose, random_state=random_state,
copy_x=copy_x, n_jobs=n_jobs, algorithm=algorithm)
if sklearn_check_version('1.0'):
@_deprecate_positional_args
def __init__(self, n_clusters=8, *, init='k-means++', n_init=10,
max_iter=300, tol=1e-4, verbose=0, random_state=None,
copy_x=True, algorithm='auto'):

super(KMeans, self).__init__(
n_clusters=n_clusters, init=init, max_iter=max_iter,
tol=tol, n_init=n_init, verbose=verbose,
random_state=random_state, copy_x=copy_x,
algorithm=algorithm)
else:
@_deprecate_positional_args
def __init__(self, n_clusters=8, *, init='k-means++', n_init=10,
max_iter=300, tol=1e-4, precompute_distances='deprecated',
verbose=0, random_state=None, copy_x=True,
n_jobs='deprecated', algorithm='auto'):

super(KMeans, self).__init__(
n_clusters=n_clusters, init=init, max_iter=max_iter,
tol=tol, precompute_distances=precompute_distances,
n_init=n_init, verbose=verbose, random_state=random_state,
copy_x=copy_x, n_jobs=n_jobs, algorithm=algorithm)

def fit(self, X, y=None, sample_weight=None):
return _fit(self, X, y=y, sample_weight=sample_weight)
Expand Down

0 comments on commit a78649b

Please sign in to comment.