Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

POC - skglm GPU support #149

Closed
wants to merge 52 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
42f6736
FISTA CPU
Badr-MOUFAD Mar 12, 2023
116dda1
cupy solver
Badr-MOUFAD Mar 12, 2023
ff79040
unittest eval optimality condition
Badr-MOUFAD Mar 13, 2023
1385900
cleanups cpu solver
Badr-MOUFAD Mar 13, 2023
e63cdb7
jax solver
Badr-MOUFAD Mar 13, 2023
d49291c
add unittest jax solver
Badr-MOUFAD Mar 13, 2023
ec8c663
pass flake8
Badr-MOUFAD Mar 13, 2023
b703d25
numba solver layout
Badr-MOUFAD Mar 13, 2023
c82e352
numba cuda utils
Badr-MOUFAD Mar 13, 2023
8f57931
fix numba solver
Badr-MOUFAD Mar 13, 2023
68c0b71
unittest numba solver
Badr-MOUFAD Mar 13, 2023
0677519
numba solver example
Badr-MOUFAD Mar 14, 2023
cd686b3
move into solvers folder
Badr-MOUFAD Mar 14, 2023
dd7298a
add README to install
Badr-MOUFAD Mar 14, 2023
3b95086
fix conda env name
Badr-MOUFAD Mar 14, 2023
204c5e7
fix bug numba solver
Badr-MOUFAD Mar 15, 2023
b7ce4fe
update unittest & example
Badr-MOUFAD Mar 15, 2023
71819a0
fix bug init numba
Badr-MOUFAD Mar 15, 2023
b6372f9
base class for Fista solvers
Badr-MOUFAD Mar 17, 2023
97ceb4d
sparse matrix support CPU & CuPy
Badr-MOUFAD Mar 17, 2023
2ad2f2c
unittest sparse data
Badr-MOUFAD Mar 17, 2023
b8c753a
base quadratic and L1
Badr-MOUFAD Apr 10, 2023
872f9b3
refactor CPU solver
Badr-MOUFAD Apr 10, 2023
0aa15b3
test utils and fixes
Badr-MOUFAD Apr 10, 2023
14efe7b
unittest FISTA CPU
Badr-MOUFAD Apr 10, 2023
05a4f36
sparse data unittest
Badr-MOUFAD Apr 10, 2023
d23ac12
modular CuPy solver
Badr-MOUFAD Apr 10, 2023
601eb86
fix cupy verbose
Badr-MOUFAD Apr 10, 2023
761ab54
modular jax
Badr-MOUFAD Apr 10, 2023
6274e5f
unittest jax
Badr-MOUFAD Apr 10, 2023
3af102e
sparse matrices modular jax
Badr-MOUFAD Apr 11, 2023
715d3fb
modular Numba solver
Badr-MOUFAD Apr 12, 2023
34cc4a8
unittest numba && dev utils
Badr-MOUFAD Apr 12, 2023
b6f971c
comments && prob formula
Badr-MOUFAD Apr 12, 2023
a7d1375
Numba with shared memory
Badr-MOUFAD Apr 12, 2023
c786a12
Numba shared memory version
Badr-MOUFAD Apr 13, 2023
2c4cd63
kernels as static methods && Numba fix tests
Badr-MOUFAD Apr 13, 2023
e3ac70a
sparse Numba solver
Badr-MOUFAD Apr 13, 2023
ce4367d
fix bug numba gradient
Badr-MOUFAD Apr 13, 2023
c8fd8a1
fix bug numba sparse residual
Badr-MOUFAD Apr 14, 2023
a6df22a
n_samples instead of shape
Badr-MOUFAD Apr 14, 2023
cf5dc9e
Numba_solver: striding for scalable kernels
Badr-MOUFAD Apr 14, 2023
20f9274
Numba_L1: striding for scalable kernels
Badr-MOUFAD Apr 14, 2023
d8d7157
Numba sparse datafit: striding
Badr-MOUFAD Apr 14, 2023
4e4e6c1
Numba dense datafit: striding
Badr-MOUFAD Apr 14, 2023
1d07d9e
info comments Numba solver
Badr-MOUFAD Apr 14, 2023
ca9f694
update installation && normalize df and pen cupy
Badr-MOUFAD Apr 14, 2023
324cac5
pytorch solver [buggy]
Badr-MOUFAD Apr 14, 2023
c5c1dfe
fix grad bug pytorch solver && unittest
Badr-MOUFAD Apr 14, 2023
32f1014
pytorch solver sparse data
Badr-MOUFAD Apr 16, 2023
0caa9f9
set order between jax pytorch && xfail sparse and auto_diff false
Badr-MOUFAD Apr 16, 2023
545a27f
test on obj value
Badr-MOUFAD Apr 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions skglm/gpu/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
## Installation

1. checkout branch
```shell
# add remote if it doesn't exist (check with: git remote -v)
git remote add Badr-MOUFAD https://github.com/Badr-MOUFAD/skglm.git

git fetch Badr-MOUFAD skglm-gpu

git checkout skglm-gpu
```

2. create then activate ``conda`` environment
```shell
# create
conda create -n skglm-gpu python=3.7

# activate env
conda activate skglm-gpu
```

3. install ``skglm`` in editable mode
```shell
pip install -e .
```

4. install dependencies
```shell
# cupy
conda install -c conda-forge cupy cudatoolkit=11.5

# pytorch
pip install torch

# jax
conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia
```
6 changes: 6 additions & 0 deletions skglm/gpu/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Solve Lasso problem using FISTA GPU-implementation.

Problem reads::

min_w (1/2n) * ||y - Xw||^2 + lmbd * ||w||_1
"""
66 changes: 66 additions & 0 deletions skglm/gpu/example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Benchmark script: compare the Numba GPU FISTA solver against the pure-CPU
# one on a synthetic correlated-design Lasso problem, reporting run time,
# objective value, and optimality criterion for both.
import time
import warnings

import numpy as np
from numpy.linalg import norm

from benchopt.datasets import make_correlated_data

from skglm.gpu.solvers import NumbaSolver, CPUSolver

from skglm.gpu.utils.host_utils import compute_obj, eval_opt_crit

# Numba emits a performance warning for under-occupied GPU grids on small
# problems; silence it to keep the benchmark output readable.
from numba.core.errors import NumbaPerformanceWarning
warnings.filterwarnings('ignore', category=NumbaPerformanceWarning)


random_state = 27
n_samples, n_features = 10_000, 500
reg = 1e-2  # regularization strength as a fraction of lmbd_max

# generate dummy data
rng = np.random.RandomState(random_state)

X, y, _ = make_correlated_data(n_samples, n_features, random_state=rng)


# set lambda
# NOTE(review): for the (1/2n) * ||y - Xw||^2 datafit, lambda_max is usually
# norm(X.T @ y, inf) / n_samples — confirm the intended scaling here.
lmbd_max = norm(X.T @ y, ord=np.inf)
lmbd = reg * lmbd_max

# NOTE(review): max_iter = 0 means both timed solves below run zero
# iterations — looks like a leftover debug value; confirm before merging.
max_iter = 0

# cache numba compilation
# NOTE(review): these warm-up/timed calls pass (X, y, lmbd), but the
# refactored solvers in this PR take (X, y, datafit, penalty) — this call
# signature looks stale; confirm against NumbaSolver.solve.
NumbaSolver(verbose=0, max_iter=2).solve(X, y, lmbd)

solver_gpu = NumbaSolver(verbose=1, max_iter=max_iter)
# solve problem
start = time.perf_counter()
w_gpu = solver_gpu.solve(X, y, lmbd)
end = time.perf_counter()

print("gpu time: ", end - start)


# cache numba compilation
# NOTE(review): CPUSolver is pure numpy except for the njit-compiled
# optimality criterion (used only when verbose) — "numba compilation" here
# presumably refers to that; verify the warm-up is actually needed.
CPUSolver(max_iter=2).solve(X, y, lmbd)

solver_cpu = CPUSolver(verbose=1, max_iter=max_iter)
start = time.perf_counter()
w_cpu = solver_cpu.solve(X, y, lmbd)
end = time.perf_counter()
print("cpu time: ", end - start)


print(
    "Objective\n"
    f"gpu : {compute_obj(X, y, lmbd, w_gpu):.8f}\n"
    f"cpu : {compute_obj(X, y, lmbd, w_cpu):.8f}"
)


print(
    "Optimality condition\n"
    f"gpu : {eval_opt_crit(X, y, lmbd, w_gpu):.8f}\n"
    f"cpu : {eval_opt_crit(X, y, lmbd, w_cpu):.8f}"
)
4 changes: 4 additions & 0 deletions skglm/gpu/solvers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from skglm.gpu.solvers.cpu_solver import CPUSolver # noqa
from skglm.gpu.solvers.cupy_solver import CupySolver # noqa
from skglm.gpu.solvers.jax_solver import JaxSolver # noqa
from skglm.gpu.solvers.numba_solver import NumbaSolver # noqa
66 changes: 66 additions & 0 deletions skglm/gpu/solvers/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from abc import ABC, abstractmethod

import numpy as np
from numba import njit
from scipy import sparse
from scipy.sparse import linalg as spicy_linalg

from skglm.utils.prox_funcs import ST_vec


class BaseFistaSolver(ABC):
    """Abstract base class for FISTA solvers (CPU, CuPy, JAX, Numba).

    Inheriting from ``ABC`` makes ``@abstractmethod`` effective: without the
    ``ABCMeta`` metaclass the decorator is inert and a subclass missing
    ``solve`` could still be instantiated silently.
    """

    @abstractmethod
    def solve(self, X, y, datafit, penalty):
        """Solve ``min_w datafit(X, y, w) + penalty(w)`` and return ``w``."""
        ...


class BaseQuadratic:
    """Quadratic datafit ``||y - Xw||^2 / (2 * n_samples)``.

    Parameters of the methods are numpy/scipy arrays; subclasses may pass
    backend arrays supporting the same operations.
    """

    def value(self, X, y, w, Xw):
        """Datafit value at ``w``, given the precomputed product ``Xw``."""
        # use the Xw the caller already computed instead of redoing X @ w
        return ((y - Xw) ** 2).sum() / (2 * len(y))

    def gradient(self, X, y, w, Xw):
        """Gradient of the datafit: ``X.T @ (Xw - y) / n_samples``."""
        return X.T @ (Xw - y) / len(y)

    def get_lipschitz_cst(self, X):
        """Lipschitz constant of the gradient: ``||X||_2^2 / n_samples``."""
        n_samples = X.shape[0]

        if sparse.issparse(X):
            # largest singular value via sparse SVD (dense norm would densify X)
            return spicy_linalg.svds(X, k=1)[1][0] ** 2 / n_samples

        return np.linalg.norm(X, ord=2) ** 2 / n_samples


class BaseL1:
    """L1 penalty ``alpha * ||w||_1`` with its prox and optimality measure."""

    def __init__(self, alpha):
        self.alpha = alpha

    def value(self, w):
        """Penalty value at ``w``."""
        return self.alpha * np.abs(w).sum()

    def prox(self, value, stepsize):
        """Soft-thresholding of ``value`` at level ``alpha * stepsize``."""
        return ST_vec(value, self.alpha * stepsize)

    def max_subdiff_distance(self, w, grad):
        """Largest coordinate-wise violation of the optimality condition."""
        return BaseL1._compute_max_subdiff_distance(w, grad, self.alpha)

    @staticmethod
    @njit("f8(f8[:], f8[:], f8)")
    def _compute_max_subdiff_distance(w, grad, lmbd):
        # distance of -grad to the subdifferential of lmbd * ||.||_1 at w,
        # taken coordinate by coordinate; the max over coordinates is returned
        max_dist = 0.
        n_features = w.shape[0]

        for j in range(n_features):
            w_j = w[j]
            grad_j = grad[j]

            if w_j == 0.:
                # subdifferential is the interval [-lmbd, lmbd]
                dist_j = max(abs(grad_j) - lmbd, 0.)
            else:
                # subdifferential is the singleton {sign(w_j) * lmbd}
                dist_j = abs(grad_j + np.sign(w_j) * lmbd)

            if dist_j > max_dist:
                max_dist = dist_j

        return max_dist
55 changes: 55 additions & 0 deletions skglm/gpu/solvers/cpu_solver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import numpy as np
from skglm.gpu.solvers.base import BaseFistaSolver


class CPUSolver(BaseFistaSolver):
    """Reference FISTA solver running on CPU with numpy arrays.

    Parameters
    ----------
    max_iter : int, default 1000
        Number of FISTA iterations.
    verbose : int, default 0
        If non-zero, print objective and optimality criterion each iteration.
    """

    def __init__(self, max_iter=1000, verbose=0):
        self.max_iter = max_iter
        self.verbose = verbose

    def solve(self, X, y, datafit, penalty):
        """Run FISTA and return the estimated coefficients ``w``."""
        n_samples, n_features = X.shape

        # compute step
        lipschitz = datafit.get_lipschitz_cst(X)
        if lipschitz == 0.:
            # X is identically zero: w = 0 is optimal
            return np.zeros(n_features)

        step = 1 / lipschitz

        # init vars
        w = np.zeros(n_features)
        old_w = np.zeros(n_features)
        mid_w = np.zeros(n_features)

        t_old = 1.

        for it in range(self.max_iter):

            # compute grad at the extrapolated point
            grad = datafit.gradient(X, y, mid_w, X @ mid_w)

            # forward / backward
            w = penalty.prox(mid_w - step * grad, step)

            if self.verbose:
                p_obj = datafit.value(X, y, w, X @ w) + penalty.value(w)

                grad = datafit.gradient(X, y, w, X @ w)
                opt_crit = penalty.max_subdiff_distance(w, grad)

                print(
                    f"Iteration {it:4}: p_obj={p_obj:.8f}, opt crit={opt_crit:.4e}"
                )

            # FISTA momentum: compute t_{k+1} BEFORE extrapolating so the
            # coefficient is (t_k - 1) / t_{k+1}.  Updating t after the
            # extrapolation (as previously) lagged the momentum by one
            # iteration, weakening the acceleration.
            t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2

            # extrapolate
            mid_w = w + ((t_old - 1) / t_new) * (w - old_w)

            # update FISTA vars
            t_old = t_new
            old_w = np.copy(w)

        return w
81 changes: 81 additions & 0 deletions skglm/gpu/solvers/cupy_solver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import cupy as cp
import cupyx.scipy.sparse as cpx

import numpy as np
from scipy import sparse

from skglm.gpu.solvers.base import BaseFistaSolver, BaseL1, BaseQuadratic


class CupySolver(BaseFistaSolver):
    """FISTA solver running on GPU through CuPy.

    The data is transferred to the device once, all iterations run on
    device arrays, and the solution is copied back to host at the end.

    Parameters
    ----------
    max_iter : int, default 1000
        Number of FISTA iterations.
    verbose : int, default 0
        If non-zero, print objective and optimality criterion each iteration.
    """

    def __init__(self, max_iter=1000, verbose=0):
        self.max_iter = max_iter
        self.verbose = verbose

    def solve(self, X, y, datafit, penalty):
        """Run FISTA on device and return ``w`` as a numpy array."""
        n_samples, n_features = X.shape

        # compute step (on host, before any device transfer)
        lipschitz = datafit.get_lipschitz_cst(X)
        if lipschitz == 0.:
            # X is identically zero: w = 0 is optimal
            return np.zeros(n_features)

        step = 1 / lipschitz

        is_X_sparse = sparse.issparse(X)

        # transfer to device
        X_gpu = cp.array(X) if not is_X_sparse else cpx.csr_matrix(X)
        y_gpu = cp.array(y)

        # init vars in device
        w = cp.zeros(n_features)
        old_w = cp.zeros(n_features)
        mid_w = cp.zeros(n_features)

        t_old = 1

        for it in range(self.max_iter):

            # compute grad at the extrapolated point
            grad = datafit.gradient(X_gpu, y_gpu, mid_w, X_gpu @ mid_w)

            # forward / backward
            val = mid_w - step * grad
            w = penalty.prox(val, step)

            if self.verbose:
                p_obj = datafit.value(X_gpu, y_gpu, w, X_gpu @ w) + penalty.value(w)

                # optimality criterion is evaluated on host arrays
                w_cpu = cp.asnumpy(w)
                grad = datafit.gradient(X, y, w_cpu, X @ w_cpu)
                opt_crit = penalty.max_subdiff_distance(w_cpu, grad)

                print(
                    f"Iteration {it:4}: p_obj={p_obj:.8f}, opt crit={opt_crit:.4e}"
                )

            # FISTA momentum: compute t_{k+1} BEFORE extrapolating so the
            # coefficient is (t_k - 1) / t_{k+1}.  Updating t after the
            # extrapolation (as previously) lagged the momentum by one
            # iteration, weakening the acceleration.
            t_new = (1 + cp.sqrt(1 + 4 * t_old ** 2)) / 2

            # extrapolate
            mid_w = w + ((t_old - 1) / t_new) * (w - old_w)

            # update FISTA vars
            t_old = t_new
            old_w = cp.copy(w)

        # transfer back to host
        w_cpu = cp.asnumpy(w)

        return w_cpu


class QuadraticCuPy(BaseQuadratic):
    """Quadratic datafit for the CuPy solver.

    Inherits everything from ``BaseQuadratic``: ``CupySolver.solve`` calls
    ``get_lipschitz_cst`` on the host array before the device transfer, and
    passes CuPy arrays to ``value`` / ``gradient`` — this relies on CuPy
    arrays implementing the NumPy operations used in the base class.
    """
    pass


class L1CuPy(BaseL1):
    """L1 penalty whose prox runs on device with CuPy arrays."""

    def prox(self, value, stepsize):
        # soft-thresholding at level alpha * stepsize, entirely on device
        level = stepsize * self.alpha
        shrunk = cp.maximum(cp.abs(value) - level, 0.)
        return cp.sign(value) * shrunk
Loading