diff --git a/skglm/gpu/README.md b/skglm/gpu/README.md
new file mode 100644
index 000000000..6fc49b761
--- /dev/null
+++ b/skglm/gpu/README.md
@@ -0,0 +1,37 @@
+## Installation
+
+1. Check out the branch
+```shell
+# add remote if it doesn't exist (check with: git remote -v)
+git remote add Badr-MOUFAD https://github.com/Badr-MOUFAD/skglm.git
+
+git fetch Badr-MOUFAD skglm-gpu
+
+git checkout skglm-gpu
+```
+
+2. Create and activate a ``conda`` environment
+```shell
+# create
+conda create -n skglm-gpu python=3.7
+
+# activate env
+conda activate skglm-gpu
+```
+
+3. Install ``skglm`` in editable mode
+```shell
+pip install -e .
+```
+
+4. Install dependencies
+```shell
+# cupy
+conda install -c conda-forge cupy cudatoolkit=11.5
+
+# pytorch
+pip install torch
+
+# jax
+conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia
+```
diff --git a/skglm/gpu/__init__.py b/skglm/gpu/__init__.py
new file mode 100644
index 000000000..c7fda3ae0
--- /dev/null
+++ b/skglm/gpu/__init__.py
@@ -0,0 +1,6 @@
+"""Solve the Lasso problem using GPU implementations of FISTA.
+
+Problem reads::
+
+    min_w (1/2n) * ||y - Xw||^2 + lmbd * ||w||_1
+"""
diff --git a/skglm/gpu/example.py b/skglm/gpu/example.py
new file mode 100644
index 000000000..3ab55a215
--- /dev/null
+++ b/skglm/gpu/example.py
@@ -0,0 +1,68 @@
+import time
+import warnings
+
+import numpy as np
+from numpy.linalg import norm
+
+from benchopt.datasets import make_correlated_data
+
+from skglm.gpu.solvers import NumbaSolver, CPUSolver
+from skglm.gpu.solvers.base import BaseQuadratic, BaseL1
+from skglm.gpu.solvers.numba_solver import QuadraticNumba, L1Numba
+
+from skglm.gpu.utils.host_utils import compute_obj, eval_opt_crit
+
+from numba.core.errors import NumbaPerformanceWarning
+warnings.filterwarnings('ignore', category=NumbaPerformanceWarning)
+
+
+random_state = 27
+n_samples, n_features = 10_000, 500
+reg = 1e-2
+
+# generate dummy data
+rng = np.random.RandomState(random_state)
+
+X, y, _ = make_correlated_data(n_samples, n_features, random_state=rng)
+
+
+# set lambda
+lmbd_max = norm(X.T @ y, ord=np.inf)
+lmbd = reg * lmbd_max
+
+max_iter = 1000
+
+# cache numba compilation
+NumbaSolver(verbose=0, max_iter=2).solve(X, y, QuadraticNumba(), L1Numba(lmbd))
+
+solver_gpu = NumbaSolver(verbose=1, max_iter=max_iter)
+# solve problem
+start = time.perf_counter()
+w_gpu = solver_gpu.solve(X, y, QuadraticNumba(), L1Numba(lmbd))
+end = time.perf_counter()
+
+print("gpu time: ", end - start)
+
+
+# cache numba compilation
+CPUSolver(max_iter=2).solve(X, y, BaseQuadratic(), BaseL1(lmbd))
+
+solver_cpu = CPUSolver(verbose=1, max_iter=max_iter)
+start = time.perf_counter()
+w_cpu = solver_cpu.solve(X, y, BaseQuadratic(), BaseL1(lmbd))
+end = time.perf_counter()
+print("cpu time: ", end - start)
+
+
+print(
+    "Objective\n"
+    f"gpu : {compute_obj(X, y, lmbd, w_gpu):.8f}\n"
+    f"cpu : {compute_obj(X, y, lmbd, w_cpu):.8f}"
+)
+
+
+print(
+    "Optimality condition\n"
+    f"gpu : {eval_opt_crit(X, y, lmbd, w_gpu):.8f}\n"
+    f"cpu : {eval_opt_crit(X, y, lmbd, w_cpu):.8f}"
+)
diff --git a/skglm/gpu/solvers/__init__.py b/skglm/gpu/solvers/__init__.py
new file mode 100644
index 000000000..e1cc3574f
--- /dev/null
+++ b/skglm/gpu/solvers/__init__.py
@@ -0,0 +1,4 @@
+from skglm.gpu.solvers.cpu_solver import CPUSolver  # noqa
+from skglm.gpu.solvers.cupy_solver import CupySolver  # noqa
+from skglm.gpu.solvers.jax_solver import JaxSolver  # noqa
+from skglm.gpu.solvers.numba_solver import NumbaSolver  # noqa
diff --git a/skglm/gpu/solvers/base.py b/skglm/gpu/solvers/base.py
new file mode 100644
index 000000000..4082a7f82
--- /dev/null
+++ b/skglm/gpu/solvers/base.py
@@ -0,0 +1,66 @@
+from numba import njit
+from abc import abstractmethod
+
+import numpy as np
+from scipy import sparse
+from scipy.sparse import linalg as spicy_linalg
+
+from skglm.utils.prox_funcs import ST_vec
+
+
+class BaseFistaSolver:
+
+    @abstractmethod
+    def solve(self, X, y, datafit, penalty):
+        ...
+
+
+class BaseQuadratic:
+
+    def value(self, X, y, w, Xw):
+        """Parameters are numpy/scipy arrays."""
+        return ((y - X @ w) ** 2).sum() / (2 * len(y))
+
+    def gradient(self, X, y, w, Xw):
+        return X.T @ (Xw - y) / len(y)
+
+    def get_lipschitz_cst(self, X):
+        n_samples = X.shape[0]
+
+        if sparse.issparse(X):
+            return spicy_linalg.svds(X, k=1)[1][0] ** 2 / n_samples
+
+        return np.linalg.norm(X, ord=2) ** 2 / n_samples
+
+
+class BaseL1:
+
+    def __init__(self, alpha):
+        self.alpha = alpha
+
+    def value(self, w):
+        return self.alpha * np.abs(w).sum()
+
+    def prox(self, value, stepsize):
+        return ST_vec(value, self.alpha * stepsize)
+
+    def max_subdiff_distance(self, w, grad):
+        return BaseL1._compute_max_subdiff_distance(w, grad, self.alpha)
+
+    @staticmethod
+    @njit("f8(f8[:], f8[:], f8)")
+    def _compute_max_subdiff_distance(w, grad, lmbd):
+        max_dist = 0.
+
+        for i in range(len(w)):
+            grad_i = grad[i]
+            w_i = w[i]
+
+            if w[i] == 0.:
+                dist = max(abs(grad_i) - lmbd, 0.)
+            else:
+                dist = abs(grad_i + np.sign(w_i) * lmbd)
+
+            max_dist = max(max_dist, dist)
+
+        return max_dist
diff --git a/skglm/gpu/solvers/cpu_solver.py b/skglm/gpu/solvers/cpu_solver.py
new file mode 100644
index 000000000..779627866
--- /dev/null
+++ b/skglm/gpu/solvers/cpu_solver.py
@@ -0,0 +1,55 @@
+import numpy as np
+from skglm.gpu.solvers.base import BaseFistaSolver
+
+
+class CPUSolver(BaseFistaSolver):
+
+    def __init__(self, max_iter=1000, verbose=0):
+        self.max_iter = max_iter
+        self.verbose = verbose
+
+    def solve(self, X, y, datafit, penalty):
+        n_samples, n_features = X.shape
+
+        # compute step
+        lipschitz = datafit.get_lipschitz_cst(X)
+        if lipschitz == 0.:
+            return np.zeros(n_features)
+
+        step = 1 / lipschitz
+
+        # init vars
+        w = np.zeros(n_features)
+        old_w = np.zeros(n_features)
+        mid_w = np.zeros(n_features)
+        grad = np.zeros(n_features)
+
+        t_old, t_new = 1, 1
+
+        for it in range(self.max_iter):
+
+            # compute grad
+            grad = datafit.gradient(X, y, mid_w, X @ mid_w)
+
+            # forward / backward
+            w = penalty.prox(mid_w - step * grad, step)
+
+            if self.verbose:
+                p_obj = datafit.value(X, y, w, X @ w) + penalty.value(w)
+
+                grad = datafit.gradient(X, y, w, X @ w)
+                opt_crit = penalty.max_subdiff_distance(w, grad)
+
+                print(
+                    f"Iteration {it:4}: p_obj={p_obj:.8f}, opt crit={opt_crit:.4e}"
+                )
+
+            # extrapolate
+            mid_w = w + ((t_old - 1) / t_new) * (w - old_w)
+
+            # update FISTA vars
+            t_old = t_new
+            t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2
+            old_w = np.copy(w)
+
+        return w
diff --git a/skglm/gpu/solvers/cupy_solver.py b/skglm/gpu/solvers/cupy_solver.py
new file mode 100644
index 000000000..0cdb72303
--- /dev/null
+++ b/skglm/gpu/solvers/cupy_solver.py
@@ -0,0 +1,81 @@
+import cupy as cp
+import cupyx.scipy.sparse as cpx
+
+import numpy as np
+from scipy import sparse
+
+from skglm.gpu.solvers.base import BaseFistaSolver, BaseL1, BaseQuadratic
+
+
+class CupySolver(BaseFistaSolver):
+
+    def __init__(self, max_iter=1000, verbose=0):
+        self.max_iter = max_iter
+        self.verbose = verbose
+
+    def solve(self, X, y, datafit, penalty):
+        n_samples, n_features = X.shape
+
+        # compute step
+        lipschitz = datafit.get_lipschitz_cst(X)
+        if lipschitz == 0.:
+            return np.zeros(n_features)
+
+        step = 1 / lipschitz
+
+        is_X_sparse = sparse.issparse(X)
+
+        # transfer to device
+        X_gpu = cp.array(X) if not is_X_sparse else cpx.csr_matrix(X)
+        y_gpu = cp.array(y)
+
+        # init vars in device
+        w = cp.zeros(n_features)
+        old_w = cp.zeros(n_features)
+        mid_w = cp.zeros(n_features)
+        grad = cp.zeros(n_features)
+
+        t_old, t_new = 1, 1
+
+        for it in range(self.max_iter):
+
+            # compute grad
+            grad = datafit.gradient(X_gpu, y_gpu, mid_w, X_gpu @ mid_w)
+
+            # forward / backward
+            val = mid_w - step * grad
+            w = penalty.prox(val, step)
+
+            if self.verbose:
+                p_obj = datafit.value(X_gpu, y_gpu, w, X_gpu @ w) + penalty.value(w)
+
+                w_cpu = cp.asnumpy(w)
+                grad = datafit.gradient(X, y, w_cpu, X @ w_cpu)
+                opt_crit = penalty.max_subdiff_distance(w_cpu, grad)
+
+                print(
+                    f"Iteration {it:4}: p_obj={p_obj:.8f}, opt crit={opt_crit:.4e}"
+                )
+
+            # extrapolate
+            mid_w = w + ((t_old - 1) / t_new) * (w - old_w)
+
+            # update FISTA vars
+            t_old = t_new
+            t_new = (1 + cp.sqrt(1 + 4 * t_old ** 2)) / 2
+            old_w = cp.copy(w)
+
+        # transfer back to host
+        w_cpu = cp.asnumpy(w)
+
+        return w_cpu
+
+
+class QuadraticCuPy(BaseQuadratic):
+    pass
+
+
+class L1CuPy(BaseL1):
+
+    def prox(self, value, stepsize):
+        return cp.sign(value) * cp.maximum(cp.abs(value) - stepsize * self.alpha, 0.)
diff --git a/skglm/gpu/solvers/jax_solver.py b/skglm/gpu/solvers/jax_solver.py
new file mode 100644
index 000000000..af8334c72
--- /dev/null
+++ b/skglm/gpu/solvers/jax_solver.py
@@ -0,0 +1,117 @@
+# if not set, an error related to the CUDA linking API is raised.
+# as recommended, set 'XLA_FLAGS' to bypass it.
+# side effect: (perhaps) slower compilation time.
+import os
+os.environ['XLA_FLAGS'] = '--xla_gpu_force_compilation_parallelism=1'  # noqa
+
+import numpy as np  # noqa
+
+import jax  # noqa
+import jax.numpy as jnp  # noqa
+# set float64 as the default float type.
+# otherwise, rounding errors are amplified.
+jax.config.update("jax_enable_x64", True)  # noqa
+
+from scipy import sparse  # noqa
+from jax.experimental import sparse as jax_sparse  # noqa
+
+from skglm.gpu.solvers.base import BaseFistaSolver, BaseQuadratic, BaseL1  # noqa
+
+
+class JaxSolver(BaseFistaSolver):
+
+    def __init__(self, max_iter=1000, use_auto_diff=True, verbose=0):
+        self.max_iter = max_iter
+        self.use_auto_diff = use_auto_diff
+        self.verbose = verbose
+
+    def solve(self, X, y, datafit, penalty):
+        n_samples, n_features = X.shape
+
+        # compute step
+        lipschitz = datafit.get_lipschitz_cst(X)
+        if lipschitz == 0.:
+            return np.zeros(n_features)
+
+        step = 1 / lipschitz
+
+        # transfer to device
+        if sparse.issparse(X):
+            # sparse matrices are still an experimental feature in jax
+            # matrix operations are supported only for COO matrices but are missing
+            # for CSC, CSR. Hence, work with COO while waiting for a new JAX release
+            # that adds support for these formats
+            X_gpu = jax_sparse.BCOO.from_scipy_sparse(X)
+        else:
+            X_gpu = jnp.asarray(X)
+        y_gpu = jnp.asarray(y)
+
+        # get grad func of datafit
+        if self.use_auto_diff:
+            auto_grad = jax.jit(jax.grad(datafit.value, argnums=-1))
+
+        # init vars in device
+        w = jnp.zeros(n_features)
+        old_w = jnp.zeros(n_features)
+        mid_w = jnp.zeros(n_features)
+        grad = jnp.zeros(n_features)
+
+        t_old, t_new = 1, 1
+
+        for it in range(self.max_iter):
+
+            # compute grad
+            if self.use_auto_diff:
+                grad = auto_grad(X_gpu, y_gpu, mid_w)
+            else:
+                grad = datafit.gradient(X_gpu, y_gpu, mid_w)
+
+            # forward / backward
+            val = mid_w - step * grad
+            w = penalty.prox(val, step)
+
+            if self.verbose:
+                p_obj = datafit.value(X_gpu, y_gpu, w) + penalty.value(w)
+
+                if self.use_auto_diff:
+                    grad = auto_grad(X_gpu, y_gpu, w)
+                else:
+                    grad = datafit.gradient(X_gpu, y_gpu, w)
+
+                w_cpu = np.asarray(w, dtype=np.float64)
+                grad_cpu = np.asarray(grad, dtype=np.float64)
+                opt_crit = penalty.max_subdiff_distance(w_cpu, grad_cpu)
+
+                print(
+                    f"Iteration {it:4}: p_obj={p_obj:.8f}, opt crit={opt_crit:.4e}"
+                )
+
+            # extrapolate
+            mid_w = w + ((t_old - 1) / t_new) * (w - old_w)
+
+            # update FISTA vars
+            t_old = t_new
+            t_new = 0.5 * (1 + jnp.sqrt(1. + 4. * t_old ** 2))
+            old_w = jnp.copy(w)
+
+        # transfer back to host
+        w_cpu = np.asarray(w, dtype=np.float64)
+
+        return w_cpu
+
+
+class QuadraticJax(BaseQuadratic):
+
+    def value(self, X_gpu, y_gpu, w):
+        n_samples = X_gpu.shape[0]
+        return jnp.sum((X_gpu @ w - y_gpu) ** 2) / (2. * n_samples)
+
+    def gradient(self, X_gpu, y_gpu, w):
+        n_samples = X_gpu.shape[0]
+        return X_gpu.T @ (X_gpu @ w - y_gpu) / n_samples
+
+
+class L1Jax(BaseL1):
+
+    def prox(self, value, stepsize):
+        return jnp.sign(value) * jnp.maximum(jnp.abs(value) - stepsize * self.alpha, 0.)
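Reviewer note: all of the backends above implement the same `solve(X, y, datafit, penalty)` interface declared in `base.py`. The following minimal sketch is not part of the diff; it mirrors `example.py` and the tests, uses only the classes introduced above, and sticks to the CPU backend so it runs without a GPU. The GPU solvers are drop-in replacements once the dependencies from the README are installed.

```python
# Editorial sketch (not part of the diff): driving the FISTA solvers above.
import numpy as np
from numpy.linalg import norm

from skglm.gpu.solvers.base import BaseQuadratic, BaseL1
from skglm.gpu.solvers.cpu_solver import CPUSolver

rng = np.random.RandomState(0)
X, y = rng.randn(100, 30), rng.randn(100)

# regularization level as a fraction of lambda_max, as in the tests
lmbd = 1e-2 * norm(X.T @ y, ord=np.inf) / len(y)

w = CPUSolver(max_iter=1000).solve(X, y, BaseQuadratic(), BaseL1(lmbd))
print(f"{np.count_nonzero(w)} non-zero coefficients out of {w.size}")

# On a CUDA machine, the GPU backends are drop-in replacements, e.g.:
#   from skglm.gpu.solvers.cupy_solver import CupySolver, QuadraticCuPy, L1CuPy
#   w = CupySolver().solve(X, y, QuadraticCuPy(), L1CuPy(lmbd))
```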
diff --git a/skglm/gpu/solvers/numba_solver.py b/skglm/gpu/solvers/numba_solver.py
new file mode 100644
index 000000000..96f30c5b3
--- /dev/null
+++ b/skglm/gpu/solvers/numba_solver.py
@@ -0,0 +1,278 @@
+import math
+import numpy as np
+from numba import cuda
+
+from scipy import sparse
+
+from skglm.gpu.solvers.base import BaseL1, BaseQuadratic, BaseFistaSolver
+
+import warnings
+from numba.core.errors import NumbaPerformanceWarning
+
+warnings.filterwarnings("ignore", category=NumbaPerformanceWarning)
+
+
+# Built from GPU properties
+# Refer to `utils` to get GPU properties
+MAX_1DIM_BLOCK = (1024,)
+MAX_2DIM_BLOCK = (32, 32)
+MAX_1DIM_GRID = (65535,)
+MAX_2DIM_GRID = (65535, 65535)
+
+
+class NumbaSolver(BaseFistaSolver):
+
+    def __init__(self, max_iter=1000, verbose=0):
+        self.max_iter = max_iter
+        self.verbose = verbose
+
+    def solve(self, X, y, datafit, penalty):
+        n_samples, n_features = X.shape
+        X_is_sparse = sparse.issparse(X)
+
+        # compute step
+        lipschitz = datafit.get_lipschitz_cst(X)
+        if lipschitz == 0.:
+            return np.zeros(n_features)
+
+        step = 1 / lipschitz
+
+        # number of blocks to use along the features axis when launching kernels
+        n_blocks_axis_1 = math.ceil(X.shape[1] / MAX_1DIM_BLOCK[0])
+
+        # transfer to device
+        if X_is_sparse:
+            X_gpu_bundles = (
+                cuda.to_device(X.data),
+                cuda.to_device(X.indptr),
+                cuda.to_device(X.indices),
+                X.shape,
+            )
+        else:
+            X_gpu = cuda.to_device(X)
+        y_gpu = cuda.to_device(y)
+
+        # init vars on device
+        # CAUTION: they must be initialized with specific values,
+        # otherwise stale values in GPU memory are used
+        w = cuda.to_device(np.zeros(n_features))
+        mid_w = cuda.to_device(np.zeros(n_features))
+        old_w = cuda.to_device(np.zeros(n_features))
+
+        # needn't be initialized with values as it only stores computation results
+        grad = cuda.device_array(n_features)
+
+        t_old, t_new = 1, 1
+
+        for it in range(self.max_iter):
+
+            # inplace update of grad
+            if X_is_sparse:
+                datafit.sparse_gradient(*X_gpu_bundles, y_gpu, mid_w, grad)
+            else:
+                datafit.gradient(X_gpu, y_gpu, mid_w, grad)
+
+            # inplace update of mid_w
+            _forward[n_blocks_axis_1, MAX_1DIM_BLOCK](mid_w, grad, step, mid_w)
+
+            # inplace update of w
+            penalty.prox(mid_w, step, w)
+
+            if self.verbose:
+                w_cpu = w.copy_to_host()
+
+                p_obj = datafit.value(X, y, w_cpu, X @ w_cpu) + penalty.value(w_cpu)
+
+                if X_is_sparse:
+                    datafit.sparse_gradient(*X_gpu_bundles, y_gpu, w, grad)
+                else:
+                    datafit.gradient(X_gpu, y_gpu, w, grad)
+                grad_cpu = grad.copy_to_host()
+
+                opt_crit = penalty.max_subdiff_distance(w_cpu, grad_cpu)
+
+                print(
+                    f"Iteration {it:4}: p_obj={p_obj:.8f}, opt crit={opt_crit:.4e}"
+                )
+
+            # extrapolate
+            coef = (t_old - 1) / t_new
+            # mid_w = w + coef * (w - old_w)
+            _extrapolate[n_blocks_axis_1, MAX_1DIM_BLOCK](
+                w, old_w, coef, mid_w)
+
+            # update FISTA vars
+            t_old = t_new
+            t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2
+            # in `copy_to_device`: `self` is destination and `other` is source
+            old_w.copy_to_device(w)
+
+        # transfer back to host
+        w_cpu = w.copy_to_host()
+
+        return w_cpu
+
+
+class QuadraticNumba(BaseQuadratic):
+
+    def gradient(self, X_gpu, y_gpu, w, out):
+        minus_residual = cuda.device_array(X_gpu.shape[0])
+
+        n_blocks_axis_0, n_blocks_axis_1 = (math.ceil(n / MAX_1DIM_BLOCK[0])
+                                            for n in X_gpu.shape)
+
+        QuadraticNumba._compute_minus_residual[n_blocks_axis_0, MAX_1DIM_BLOCK](
+            X_gpu, y_gpu, w, minus_residual)
+
+        QuadraticNumba._compute_grad[n_blocks_axis_1, MAX_1DIM_BLOCK](
+            X_gpu, minus_residual, out)
+
+    def sparse_gradient(self, X_gpu_data, X_gpu_indptr, X_gpu_indices, X_gpu_shape,
+                        y_gpu, w, out):
+        # init to zero as it is read during the computation,
+        # otherwise stale values on GPU are used
+        minus_residual = cuda.to_device(np.zeros(X_gpu_shape[0]))
+
+        n_blocks = math.ceil(X_gpu_shape[1] / MAX_1DIM_BLOCK[0])
+
+        QuadraticNumba._sparse_compute_minus_residual[n_blocks, MAX_1DIM_BLOCK](
+            X_gpu_data, X_gpu_indptr, X_gpu_indices, X_gpu_shape,
+            y_gpu, w, minus_residual)
+
+        QuadraticNumba._sparse_compute_grad[n_blocks, MAX_1DIM_BLOCK](
+            X_gpu_data, X_gpu_indptr, X_gpu_indices, X_gpu_shape,
+            minus_residual, out)
+
+    @staticmethod
+    @cuda.jit
+    def _compute_minus_residual(X_gpu, y_gpu, w, out):
+        # compute: out = X_gpu @ w - y_gpu
+        i = cuda.grid(1)
+
+        n_samples, n_features = X_gpu.shape
+        stride_x = cuda.gridDim.x * cuda.blockDim.x
+
+        for ii in range(i, n_samples, stride_x):
+
+            # out[ii] = X_gpu[ii, :] @ w - y_gpu[ii]
+            tmp = 0.
+            for j in range(n_features):
+                tmp += X_gpu[ii, j] * w[j]
+            tmp -= y_gpu[ii]
+
+            out[ii] = tmp
+
+    @staticmethod
+    @cuda.jit
+    def _compute_grad(X_gpu, minus_residual, out):
+        # compute: out = X.T @ minus_residual
+        j = cuda.grid(1)
+
+        n_samples, n_features = X_gpu.shape
+        stride_y = cuda.gridDim.x * cuda.blockDim.x
+
+        for jj in range(j, n_features, stride_y):
+
+            # out[jj] = X_gpu[:, jj] @ minus_residual / n_samples
+            tmp = 0.
+            for i in range(n_samples):
+                tmp += X_gpu[i, jj] * minus_residual[i] / n_samples
+
+            out[jj] = tmp
+
+    @staticmethod
+    @cuda.jit
+    def _sparse_compute_minus_residual(X_gpu_data, X_gpu_indptr, X_gpu_indices,
+                                       X_gpu_shape, y_gpu, w, out):
+        j = cuda.grid(1)
+
+        n_samples, n_features = X_gpu_shape
+        stride_y = cuda.gridDim.x * cuda.blockDim.x
+
+        for jj in range(j, n_features, stride_y):
+
+            # out -= y_gpu
+            # small hack to perform this operation using
+            # the (features) threads instead of launching others
+            for idx in range(jj, n_samples, n_features):
+                cuda.atomic.sub(out, idx, y_gpu[idx])
+
+            # out += w[jj] * X_gpu[:, jj]
+            for idx in range(X_gpu_indptr[jj], X_gpu_indptr[jj+1]):
+                i = X_gpu_indices[idx]
+                cuda.atomic.add(out, i, w[jj] * X_gpu_data[idx])
+
+    @staticmethod
+    @cuda.jit
+    def _sparse_compute_grad(X_gpu_data, X_gpu_indptr, X_gpu_indices, X_gpu_shape,
+                             minus_residual, out):
+        j = cuda.grid(1)
+
+        n_samples, n_features = X_gpu_shape
+        stride_y = cuda.gridDim.x * cuda.blockDim.x
+
+        for jj in range(j, n_features, stride_y):
+
+            # out[jj] = X_gpu[:, jj] @ minus_residual / n_samples
+            tmp = 0.
+            for idx in range(X_gpu_indptr[jj], X_gpu_indptr[jj+1]):
+                i = X_gpu_indices[idx]
+                tmp += X_gpu_data[idx] * minus_residual[i] / n_samples
+
+            out[jj] = tmp
+
+
+class L1Numba(BaseL1):
+
+    def prox(self, value, stepsize, out):
+        level = stepsize * self.alpha
+
+        n_blocks = math.ceil(value.shape[0] / MAX_1DIM_BLOCK[0])
+
+        L1Numba._ST_vec[n_blocks, MAX_1DIM_BLOCK](value, level, out)
+
+    @staticmethod
+    @cuda.jit
+    def _ST_vec(value, level, out):
+        j = cuda.grid(1)
+
+        n_features = value.shape[0]
+        stride_y = cuda.gridDim.x * cuda.blockDim.x
+
+        # out = ST(value, level)
+        for jj in range(j, n_features, stride_y):
+            value_j = value[jj]
+
+            if abs(value_j) <= level:
+                value_j = 0.
+            elif value_j > level:
+                value_j = value_j - level
+            else:
+                value_j = value_j + level
+
+            out[jj] = value_j
+
+
+# solver kernels
+@cuda.jit
+def _forward(mid_w, grad, step, out):
+    j = cuda.grid(1)
+
+    n_features = mid_w.shape[0]
+    stride_y = cuda.gridDim.x * cuda.blockDim.x
+
+    # out = mid_w - step * grad
+    for jj in range(j, n_features, stride_y):
+        out[jj] = mid_w[jj] - step * grad[jj]
+
+
+@cuda.jit
+def _extrapolate(w, old_w, coef, out):
+    j = cuda.grid(1)
+
+    n_features = w.shape[0]
+    stride_y = cuda.gridDim.x * cuda.blockDim.x
+
+    # out = w + coef * (w - old_w)
+    for jj in range(j, n_features, stride_y):
+        out[jj] = w[jj] + coef * (w[jj] - old_w[jj])
diff --git a/skglm/gpu/solvers/pytorch_solver.py b/skglm/gpu/solvers/pytorch_solver.py
new file mode 100644
index 000000000..341950396
--- /dev/null
+++ b/skglm/gpu/solvers/pytorch_solver.py
@@ -0,0 +1,123 @@
+import torch
+
+import numpy as np
+from scipy import sparse
+
+from skglm.gpu.solvers.base import BaseFistaSolver, BaseQuadratic, BaseL1
+
+
+class PytorchSolver(BaseFistaSolver):
+
+    def __init__(self, max_iter=1000, use_auto_diff=True, verbose=0):
+        self.max_iter = max_iter
+        self.use_auto_diff = use_auto_diff
+        self.verbose = verbose
+
+    def solve(self, X, y, datafit, penalty):
+        n_samples, n_features = X.shape
+        X_is_sparse = sparse.issparse(X)
+
+        if X_is_sparse and not self.use_auto_diff:
+            error_message = (
+                "PyTorch doesn't support the operation `M.T @ vec` "
+                "for sparse matrices. Use `use_auto_diff=True`"
+            )
+
+            raise ValueError(error_message)
+
+        # compute step
+        lipschitz = datafit.get_lipschitz_cst(X)
+        if lipschitz == 0.:
+            return np.zeros(n_features)
+
+        step = 1 / lipschitz
+
+        # transfer data
+        selected_device = torch.device("cuda")
+        if X_is_sparse:
+            X_gpu = torch.sparse_csc_tensor(
+                X.indptr, X.indices, X.data, X.shape,
+                dtype=torch.float64,
+                device=selected_device
+            )
+        else:
+            X_gpu = torch.tensor(X, device=selected_device)
+        y_gpu = torch.tensor(y, device=selected_device)
+
+        # init vars
+        w = torch.zeros(n_features, dtype=torch.float64, device=selected_device)
+        old_w = torch.zeros(n_features, dtype=torch.float64, device=selected_device)
+        mid_w = torch.zeros(n_features, dtype=torch.float64, device=selected_device,
+                            requires_grad=self.use_auto_diff)
+
+        t_old, t_new = 1, 1
+
+        for it in range(self.max_iter):
+
+            # compute gradient
+            if self.use_auto_diff:
+                datafit_value = datafit.value(X_gpu, y_gpu, mid_w)
+                datafit_value.backward()
+
+                grad = mid_w.grad
+            else:
+                grad = datafit.gradient(X_gpu, y_gpu, mid_w)
+
+            # forward / backward
+            with torch.no_grad():
+                w = penalty.prox(mid_w - step * grad, step)
+
+            if self.verbose:
+                # transfer back to host
+                w_cpu = w.cpu().numpy()
+
+                p_obj = datafit.value(X, y, w_cpu) + penalty.value(w_cpu)
+
+                if self.use_auto_diff:
+                    w_tmp = torch.tensor(w, dtype=torch.float64,
+                                         device=selected_device, requires_grad=True)
+
+                    datafit_value = datafit.value(X_gpu, y_gpu, w_tmp)
+                    datafit_value.backward()
+
+                    grad_cpu = w_tmp.grad.detach().cpu().numpy()
+                else:
+                    grad_cpu = datafit.gradient(X, y, w_cpu)
+
+                opt_crit = penalty.max_subdiff_distance(w_cpu, grad_cpu)
+
+                print(
+                    f"Iteration {it:4}: p_obj={p_obj:.8f}, opt crit={opt_crit:.4e}"
+                )
+
+            # extrapolate
+            mid_w = w + ((t_old - 1) / t_new) * (w - old_w)
+            mid_w = mid_w.requires_grad_(self.use_auto_diff)
+
+            # update FISTA vars
+            t_old = t_new
+            t_new = 0.5 * (1 + np.sqrt(1. + 4. * t_old ** 2))
+            # no need to copy `w` since its update (forward/backward)
+            # creates a new instance
+            old_w = w
+
+        # transfer back to host
+        w_cpu = w.cpu().numpy()
+
+        return w_cpu
+
+
+class QuadraticPytorch(BaseQuadratic):
+
+    def value(self, X, y, w):
+        return ((y - X @ w) ** 2).sum() / (2 * len(y))
+
+    def gradient(self, X, y, w):
+        return X.T @ (X @ w - y) / X.shape[0]
+
+
+class L1Pytorch(BaseL1):
+
+    def prox(self, value, stepsize):
+        shifted_value = torch.abs(value) - stepsize * self.alpha
+        return torch.sign(value) * torch.maximum(shifted_value, torch.tensor(0.))
diff --git a/skglm/gpu/tests/test_fista.py b/skglm/gpu/tests/test_fista.py
new file mode 100644
index 000000000..b8d3afac0
--- /dev/null
+++ b/skglm/gpu/tests/test_fista.py
@@ -0,0 +1,67 @@
+import pytest
+
+from scipy import sparse
+
+import numpy as np
+from numpy.linalg import norm
+
+from skglm.estimators import Lasso
+
+from skglm.gpu.solvers import CPUSolver
+from skglm.gpu.solvers.base import BaseQuadratic, BaseL1
+
+from skglm.gpu.solvers.jax_solver import JaxSolver, QuadraticJax, L1Jax
+from skglm.gpu.solvers.cupy_solver import CupySolver, QuadraticCuPy, L1CuPy
+from skglm.gpu.solvers.numba_solver import NumbaSolver, QuadraticNumba, L1Numba
+from skglm.gpu.solvers.pytorch_solver import PytorchSolver, QuadraticPytorch, L1Pytorch
+
+from skglm.gpu.utils.host_utils import eval_opt_crit, compute_obj
+
+
+@pytest.mark.parametrize("sparse_X", [True, False])
+@pytest.mark.parametrize(
+    "solver, datafit_cls, penalty_cls",
+    [
+        (CPUSolver(), BaseQuadratic, BaseL1),
+        (CupySolver(), QuadraticCuPy, L1CuPy),
+        (PytorchSolver(use_auto_diff=True), QuadraticPytorch, L1Pytorch),
+        (PytorchSolver(use_auto_diff=False), QuadraticPytorch, L1Pytorch),
+        (JaxSolver(use_auto_diff=True), QuadraticJax, L1Jax),
+        (JaxSolver(use_auto_diff=False), QuadraticJax, L1Jax),
+        (NumbaSolver(), QuadraticNumba, L1Numba)
+    ])
+def test_solves(solver, datafit_cls, penalty_cls, sparse_X):
+    if (sparse_X and isinstance(solver, PytorchSolver) and not solver.use_auto_diff):
+        pytest.xfail(reason="PyTorch doesn't support `M.T @ vec` for sparse matrices")
+
+    random_state = 1265
+    n_samples, n_features = 100, 30
+    reg = 1e-2
+
+    # generate dummy data
+    rng = np.random.RandomState(random_state)
+    if sparse_X:
+        X = sparse.rand(n_samples, n_features, density=0.1,
+                        format="csc", random_state=rng)
+    else:
+        X = rng.randn(n_samples, n_features)
+    y = rng.randn(n_samples)
+
+    # set lambda
+    lmbd_max = norm(X.T @ y, ord=np.inf) / n_samples
+    lmbd = reg * lmbd_max
+
+    w = solver.solve(X, y, datafit_cls(), penalty_cls(lmbd))
+    estimator = Lasso(alpha=lmbd, fit_intercept=False).fit(X, y)
+
+    stop_crit = eval_opt_crit(X, y, lmbd, w)
+
+    np.testing.assert_allclose(stop_crit, 0., atol=1e-8)
+    np.testing.assert_allclose(
+        compute_obj(X, y, lmbd, estimator.coef_),
+        compute_obj(X, y, lmbd, w),
+    )
+
+
+if __name__ == "__main__":
+    pass
diff --git a/skglm/gpu/tests/test_utils.py b/skglm/gpu/tests/test_utils.py
new file mode 100644
index 000000000..3274f5358
--- /dev/null
+++ b/skglm/gpu/tests/test_utils.py
@@ -0,0 +1,42 @@
+import numpy as np
+from skglm.gpu.utils.host_utils import compute_obj, eval_opt_crit
+from skglm.gpu.utils.device_utils import get_device_properties
+
+from sklearn.linear_model import Lasso
+
+
+def test_compute_obj():
+
+    # generate dummy data
+    X = np.eye(3)
+    y = np.array([1, 0, 1])
+    w = np.array([1, 2, -3])
+    lmbd = 10.
+
+    p_obj = compute_obj(X, y, lmbd, w)
+
+    np.testing.assert_array_equal(p_obj, 20 / (2 * 3) + 10. * 6)
+
+
+def test_eval_optimality():
+    rng = np.random.RandomState(1235)
+    n_samples, n_features = 10, 5
+
+    X = rng.randn(n_samples, n_features)
+    y = rng.randn(n_samples)
+    lmbd = 1.
+
+    estimator = Lasso(
+        alpha=lmbd, fit_intercept=False, tol=1e-9
+    ).fit(X, y)
+
+    np.testing.assert_allclose(
+        eval_opt_crit(X, y, lmbd, estimator.coef_), 0.,
+        atol=1e-9
+    )
+
+
+def test_device_props():
+    # check it runs and result is a dict
+    dev_props = get_device_properties()
+    np.testing.assert_equal(type(dev_props), dict)
diff --git a/skglm/gpu/utils/__init__.py b/skglm/gpu/utils/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/skglm/gpu/utils/device_utils.py b/skglm/gpu/utils/device_utils.py
new file mode 100644
index 000000000..ef6d331cd
--- /dev/null
+++ b/skglm/gpu/utils/device_utils.py
@@ -0,0 +1,16 @@
+from numba import cuda
+from numba.cuda.cudadrv import enums
+
+
+# modified version of code in
+# https://stackoverflow.com/questions/62457151/access-gpu-hardware-specifications-in-python  # noqa
+def get_device_properties():
+    device = cuda.get_current_device()
+
+    device_props_name = [name.replace("CU_DEVICE_ATTRIBUTE_", "")
+                         for name in dir(enums)
+                         if name.startswith("CU_DEVICE_ATTRIBUTE_")]
+
+    device_props_value = [getattr(device, prop) for prop in device_props_name]
+
+    return dict(zip(device_props_name, device_props_value))
diff --git a/skglm/gpu/utils/host_utils.py b/skglm/gpu/utils/host_utils.py
new file mode 100644
index 000000000..62e7d0b11
--- /dev/null
+++ b/skglm/gpu/utils/host_utils.py
@@ -0,0 +1,33 @@
+import numpy as np
+from numpy.linalg import norm
+
+from numba import njit
+
+
+def compute_obj(X, y, lmbd, w):
+    return norm(y - X @ w) ** 2 / (2 * len(y)) + lmbd * norm(w, ord=1)
+
+
+def eval_opt_crit(X, y, lmbd, w):
+    grad = X.T @ (X @ w - y) / len(y)
+    opt_crit = _compute_dist_subdiff(w, grad, lmbd)
+
+    return opt_crit
+
+
+@njit("f8(f8[:], f8[:], f8)")
+def _compute_dist_subdiff(w, grad, lmbd):
+    max_dist = 0.
+
+    for i in range(len(w)):
+        grad_i = grad[i]
+        w_i = w[i]
+
+        if w[i] == 0.:
+            dist = max(abs(grad_i) - lmbd, 0.)
+        else:
+            dist = abs(grad_i + np.sign(w_i) * lmbd)
+
+        max_dist = max(max_dist, dist)
+
+    return max_dist
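Reviewer note: `eval_opt_crit` above measures optimality as the largest violation of the Lasso subdifferential (KKT) conditions, which is why the tests expect it to be close to zero for any converged solver. The short check below is not part of the diff; it mirrors `test_eval_optimality` and assumes scikit-learn is available.

```python
# Editorial sketch (not part of the diff): what `eval_opt_crit` measures.
# For the Lasso, w* is optimal iff, with g = X.T @ (X @ w* - y) / n_samples,
#   |g_j| <= lmbd               wherever w*_j == 0,
#   g_j == -lmbd * sign(w*_j)   wherever w*_j != 0.
# `_compute_dist_subdiff` returns the largest violation of these conditions.
import numpy as np
from sklearn.linear_model import Lasso

from skglm.gpu.utils.host_utils import eval_opt_crit

rng = np.random.RandomState(0)
X, y = rng.randn(50, 10), rng.randn(50)
lmbd = 0.1

w_ref = Lasso(alpha=lmbd, fit_intercept=False, tol=1e-10).fit(X, y).coef_
print(eval_opt_crit(X, y, lmbd, w_ref))  # expected to be close to 0
```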