Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

linsys solver using hipsparse #293

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ src/scs_version.o: src/scs_version.c $(INC_FILES)
# Explicit source/header prerequisites for the linear-system backend objects.
$(DIRSRC)/private.o: $(DIRSRC)/private.c $(DIRSRC)/private.h
$(INDIRSRC)/indirect/private.o: $(INDIRSRC)/private.c $(INDIRSRC)/private.h
$(MKLSRC)/private.o: $(MKLSRC)/private.c $(MKLSRC)/private.h
# The HIP backend carries its own recipe: it is compiled with $(HIPCC) and the
# extra $(HIPCFLAGS). $< is the first prerequisite, i.e. $(HIPSRC)/private.c.
$(HIPSRC)/private.o: $(HIPSRC)/private.c $(HIPSRC)/private.h
	$(HIPCC) $(CFLAGS) $(HIPCFLAGS) -I$(HIPSRC) -c $< -o $@
$(LINSYS)/scs_matrix.o: $(LINSYS)/scs_matrix.c $(LINSYS)/scs_matrix.h
$(LINSYS)/csparse.o: $(LINSYS)/csparse.c $(LINSYS)/csparse.h

Expand All @@ -69,6 +71,11 @@ $(OUT)/libscsmkl.a: $(SCS_O) $(SCS_OBJECTS) $(MKLSRC)/private.o $(LINSYS)/scs_ma
$(ARCHIVE) $@ $^
- $(RANLIB) $@

# Static archive made of the core SCS objects plus the HIP linear-system
# backend object. The output directory is created first; the leading '-' on
# the last line lets the build continue even if $(RANLIB) fails.
$(OUT)/libscship.a: $(SCS_O) $(SCS_OBJECTS) $(HIPSRC)/private.o $(LINSYS)/scs_matrix.o $(LINSYS)/csparse.o
mkdir -p $(OUT)
$(ARCHIVE) $@ $^
- $(RANLIB) $@

# Shared library linking the core SCS objects with the $(DIRSRC) backend and
# the AMD/LDL objects. $(@:$(OUT)/%=%) strips the $(OUT)/ prefix so the name
# passed with $(SONAME) is the bare file name.
$(OUT)/libscsdir.$(SHARED): $(SCS_O) $(SCS_OBJECTS) $(DIRSRC)/private.o $(AMD_OBJS) $(LDL_OBJS) $(LINSYS)/scs_matrix.o $(LINSYS)/csparse.o
mkdir -p $(OUT)
$(CC) $(CFLAGS) -shared -Wl,$(SONAME),$(@:$(OUT)/%=%) -o $@ $^ $(LDFLAGS) $(BLASLDFLAGS)
Expand All @@ -81,6 +88,10 @@ $(OUT)/libscsmkl.$(SHARED): $(SCS_O) $(SCS_OBJECTS) $(MKLSRC)/private.o $(LINSYS
mkdir -p $(OUT)
$(CC) $(CFLAGS) -shared -Wl,$(SONAME),$(@:$(OUT)/%=%) -o $@ $^ $(LDFLAGS) $(MKLFLAGS)

# Shared library linking the core SCS objects with the HIP backend object.
# $(@:$(OUT)/%=%) strips the $(OUT)/ prefix so the name passed with $(SONAME)
# is the bare file name; the HIP runtime and hipSPARSE come from $(HIPLDFLAGS).
$(OUT)/libscship.$(SHARED): $(SCS_O) $(SCS_OBJECTS) $(HIPSRC)/private.o $(LINSYS)/scs_matrix.o $(LINSYS)/csparse.o
mkdir -p $(OUT)
$(CC) $(CFLAGS) -shared -Wl,$(SONAME),$(@:$(OUT)/%=%) -o $@ $^ $(LDFLAGS) $(HIPLDFLAGS)

# Demo executable: test/random_socp_prob.c linked against the static
# libscsdir.a archive ($^ expands to both prerequisites).
$(OUT)/demo_socp_direct: test/random_socp_prob.c $(OUT)/libscsdir.a
$(CC) $(CFLAGS) -o $@ $^ $(LDFLAGS) $(BLASLDFLAGS)

Expand All @@ -90,6 +101,9 @@ $(OUT)/demo_socp_indirect: test/random_socp_prob.c $(OUT)/libscsindir.a
# Demo executable: test/random_socp_prob.c linked against the static
# libscsmkl.a archive, with the MKL libraries supplied by $(MKLFLAGS).
$(OUT)/demo_socp_mkl: test/random_socp_prob.c $(OUT)/libscsmkl.a
$(CC) $(CFLAGS) -o $@ $^ $(LDFLAGS) $(MKLFLAGS)

# Demo executable: test/random_socp_prob.c linked against the static
# libscship.a archive; needs both the BLAS flags and the HIP/hipSPARSE flags.
$(OUT)/demo_socp_hip: test/random_socp_prob.c $(OUT)/libscship.a
$(CC) $(CFLAGS) -o $@ $^ $(LDFLAGS) $(BLASLDFLAGS) $(HIPLDFLAGS)

# Executable built from test/run_from_file.c and the static libscsdir.a archive.
$(OUT)/run_from_file_direct: test/run_from_file.c $(OUT)/libscsdir.a
$(CC) $(CFLAGS) -o $@ $^ $(LDFLAGS) $(BLASLDFLAGS)

Expand All @@ -108,7 +122,8 @@ $(OUT)/run_tests_direct: test/run_tests.c $(OUT)/libscsdir.a
$(CC) $(CFLAGS) -o $@ $^ $(LDFLAGS) $(BLASLDFLAGS) -Itest
# Test runner built from test/run_tests.c and libscsmkl.a; -Itest adds the
# test directory to the include search path.
$(OUT)/run_tests_mkl: test/run_tests.c $(OUT)/libscsmkl.a
$(CC) $(CFLAGS) -o $@ $^ $(LDFLAGS) $(MKLFLAGS) -Itest

# Test runner built from test/run_tests.c and libscship.a; -Itest adds the
# test directory to the include search path.
$(OUT)/run_tests_hip: test/run_tests.c $(OUT)/libscship.a
$(CC) $(CFLAGS) -o $@ $^ $(LDFLAGS) $(BLASLDFLAGS) $(HIPLDFLAGS) -Itest

.PHONY: test_gpu
test_gpu: $(OUT)/run_tests_gpu_indirect # $(OUT)/run_tests_gpu_direct
Expand All @@ -120,6 +135,8 @@ ifndef MKLROOT
$(error MKLROOT is undefined, set MKLROOT to the MKL install location)
endif

# Convenience target building every HIP artefact: static and shared library,
# the test runner and the demo. Declared phony so that a file or directory
# named "hip" can never make the target look up to date.
.PHONY: hip
hip: $(OUT)/libscship.a $(OUT)/libscship.$(SHARED) $(OUT)/run_tests_hip $(OUT)/demo_socp_hip

# Test runner built from test/run_tests.c and libscsgpuindir.a, linked with
# the BLAS flags and $(CULDFLAGS); -Itest adds the test directory to the
# include search path.
$(OUT)/run_tests_gpu_indirect: test/run_tests.c $(OUT)/libscsgpuindir.a
$(CC) $(CFLAGS) -o $@ $^ $(LDFLAGS) $(BLASLDFLAGS) $(CULDFLAGS) -Itest
Expand Down
254 changes: 254 additions & 0 deletions linsys/hip/direct/private.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
#include "private.h"

#include <hip/hip_runtime.h>
#include <hipsparse/hipsparse.h>

/* Operation and solve-policy arguments shared by every csrsv2 call in this
 * file: the matrix is used as stored (no transpose) and no level information
 * is requested from the analysis. */
const hipsparseOperation_t NON_TRANSP = HIPSPARSE_OPERATION_NON_TRANSPOSE;
const hipsparseSolvePolicy_t NO_LVL_POLICY = HIPSPARSE_SOLVE_POLICY_NO_LEVEL;

/* Returns a static, human-readable name identifying this linear-system
 * backend. The string is a literal and must not be freed by the caller. */
const char *scs_get_lin_sys_method(void)
{
  return "hipsparse-direct";
}

/*
 * Releases everything owned by a linear-system workspace: device buffers,
 * hipSPARSE objects, the host KKT matrix, the host bookkeeping arrays and
 * finally the struct itself. A NULL workspace is a no-op, and members that
 * are still NULL are skipped, so a partially initialised workspace is safe
 * to pass in.
 */
void scs_free_lin_sys_work(ScsLinSysWork *work)
{
  int k;

  if (work == NULL)
  {
    return;
  }

  {
    /* Device allocations, released in the order they appear in the struct. */
    void *device_bufs[] = {work->d_vals, work->d_row_ptrs, work->d_col_inds,
                           work->d_b, work->d_x, work->buffer};
    const int n_bufs = (int)(sizeof(device_bufs) / sizeof(device_bufs[0]));

    for (k = 0; k < n_bufs; ++k)
    {
      if (device_bufs[k] != NULL)
      {
        hipFree(device_bufs[k]);
      }
    }
  }

  /* hipSPARSE objects: solve info first, then the descriptor, then the handle. */
  if (work->info_LU != NULL)
  {
    hipsparseDestroyCsrsv2Info(work->info_LU);
  }
  if (work->descr != NULL)
  {
    hipsparseDestroyMatDescr(work->descr);
  }
  if (work->handle != NULL)
  {
    hipsparseDestroy(work->handle);
  }

  /* Host-side KKT matrix. */
  if (work->kkt != NULL)
  {
    (SCS(cs_spfree)(work->kkt));
  }

  /* Host-side arrays used when the diagonal is updated. */
  if (work->diag_r_idxs != NULL)
  {
    scs_free(work->diag_r_idxs);
  }
  if (work->diag_p != NULL)
  {
    scs_free(work->diag_p);
  }

  /* The workspace struct itself goes last. */
  scs_free(work);
}

/* Reports a failed hipSPARSE initialisation step; returns non-zero on failure. */
static int init_step_failed(hipsparseStatus_t status, const char *step)
{
  if (status == HIPSPARSE_STATUS_SUCCESS)
  {
    return 0;
  }
  scs_printf("Error in init -- %s: %d.\n", step, (int)status);
  return 1;
}

/*
 * Uploads the KKT matrix stored in work->kkt to the device and creates the
 * hipSPARSE objects used by the solve routines (matrix descriptor, library
 * handle, csrsv2 info object and scratch buffer).
 *
 * Returns HIPSPARSE_STATUS_SUCCESS when every step succeeds. On the first
 * failure a diagnostic is printed and the failing status is returned at once
 * (HIPSPARSE_STATUS_ALLOC_FAILED for a failed device allocation,
 * HIPSPARSE_STATUS_INTERNAL_ERROR for a failed host-to-device copy).
 * Whatever was created before the failure stays in `work`, so the caller can
 * release it with scs_free_lin_sys_work().
 */
hipsparseStatus_t __initialize_work(ScsLinSysWork *work)
{
  ScsMatrix *A = work->kkt;
  hipsparseStatus_t status;
  scs_float alpha = 1.0;
  /* The hipSPARSE API reports the buffer size through an `int *`; use a real
   * int so this stays correct even if scs_int is a wider type. */
  int buffer_bytes = 0;
  /* The last entry of the pointer array is the number of stored non-zeros. */
  int nnz = A->p[A->n];

  /* Matrix descriptor: zero-based, symmetric, upper fill, non-unit diagonal. */
  status = hipsparseCreateMatDescr(&(work->descr));
  if (init_step_failed(status, "descriptor"))
  {
    return status;
  }
  status = hipsparseSetMatIndexBase(work->descr, HIPSPARSE_INDEX_BASE_ZERO);
  if (init_step_failed(status, "index 0"))
  {
    return status;
  }
  status = hipsparseSetMatType(work->descr, HIPSPARSE_MATRIX_TYPE_SYMMETRIC);
  if (init_step_failed(status, "type symmetric"))
  {
    return status;
  }
  status = hipsparseSetMatFillMode(work->descr, HIPSPARSE_FILL_MODE_UPPER);
  if (init_step_failed(status, "fill upper"))
  {
    return status;
  }
  status = hipsparseSetMatDiagType(work->descr, HIPSPARSE_DIAG_TYPE_NON_UNIT);
  if (init_step_failed(status, "diagonal non-unit"))
  {
    return status;
  }

  /* Device storage for the matrix (values, pointer array, index array) and
   * for the right-hand side / solution vectors. The || chain stops at the
   * first failure, so members not reached keep their previous value. */
  if (hipMalloc((void **)&(work->d_vals), nnz * sizeof(scs_float)) != hipSuccess ||
      hipMalloc((void **)&(work->d_row_ptrs), (A->n + 1) * sizeof(scs_int)) != hipSuccess ||
      hipMalloc((void **)&(work->d_col_inds), nnz * sizeof(scs_int)) != hipSuccess ||
      hipMalloc((void **)&(work->d_b), A->m * sizeof(scs_float)) != hipSuccess ||
      hipMalloc((void **)&(work->d_x), A->m * sizeof(scs_float)) != hipSuccess)
  {
    scs_printf("Error in init -- device allocation.\n");
    return HIPSPARSE_STATUS_ALLOC_FAILED;
  }

  /* Copy the matrix to the device. */
  if (hipMemcpy(work->d_vals, A->x, nnz * sizeof(scs_float), hipMemcpyHostToDevice) != hipSuccess ||
      hipMemcpy(work->d_row_ptrs, A->p, (A->n + 1) * sizeof(scs_int), hipMemcpyHostToDevice) != hipSuccess ||
      hipMemcpy(work->d_col_inds, A->i, nnz * sizeof(scs_int), hipMemcpyHostToDevice) != hipSuccess)
  {
    scs_printf("Error in init -- copy to device.\n");
    return HIPSPARSE_STATUS_INTERNAL_ERROR;
  }

  /* Library handle and csrsv2 info object. */
  status = hipsparseCreate(&(work->handle));
  if (init_step_failed(status, "handle"))
  {
    return status;
  }
  status = hipsparseCreateCsrsv2Info(&(work->info_LU));
  if (init_step_failed(status, "create info"))
  {
    return status;
  }

  /* Query and allocate the scratch buffer needed by analysis and solve. */
  status = hipsparseDcsrsv2_bufferSize(work->handle, NON_TRANSP, A->m, nnz, work->descr,
                                       work->d_vals, work->d_row_ptrs, work->d_col_inds,
                                       work->info_LU, &buffer_bytes);
  if (init_step_failed(status, "bufferSize"))
  {
    return status;
  }
  work->bufferSize = (scs_int)buffer_bytes;

  if (hipMalloc(&(work->buffer), work->bufferSize) != hipSuccess)
  {
    scs_printf("Error in init -- device allocation.\n");
    return HIPSPARSE_STATUS_ALLOC_FAILED;
  }

  /* Analysis phase of csrsv2. */
  status = hipsparseDcsrsv2_analysis(work->handle, NON_TRANSP, A->m, nnz, work->descr,
                                     work->d_vals, work->d_row_ptrs, work->d_col_inds,
                                     work->info_LU, NO_LVL_POLICY, work->buffer);
  if (init_step_failed(status, "analysis"))
  {
    return status;
  }

  /* NOTE(review): the original author labels this call "numeric
   * factorization", but csrsv2_solve is documented as a sparse triangular
   * solve and is invoked here with NULL right-hand side and solution
   * pointers. Kept as-is; confirm that this is the intended API usage. */
  status = hipsparseDcsrsv2_solve(work->handle, NON_TRANSP, A->m, nnz, &alpha, work->descr,
                                  work->d_vals, work->d_row_ptrs, work->d_col_inds,
                                  work->info_LU, NULL, NULL, NO_LVL_POLICY, work->buffer);
  if (init_step_failed(status, "numeric fact"))
  {
    return status;
  }
  return status;
}

/*
 * Allocates and initialises the linear-system workspace for the problem
 * described by A, P and diag_r.
 *
 * Builds the KKT matrix on the host with SCS(form_kkt), then hands the
 * workspace to __initialize_work() for the device-side setup.
 *
 * Returns the new workspace, or SCS_NULL if a host allocation fails, the KKT
 * matrix cannot be formed, or the device-side setup reports an error. On
 * every failure path the partially built workspace is released before
 * returning.
 */
ScsLinSysWork *scs_init_lin_sys_work(const ScsMatrix *A, const ScsMatrix *P,
                                     const scs_float *diag_r)
{
  hipsparseStatus_t status;
  scs_int n_plus_m;
  ScsLinSysWork *p = (ScsLinSysWork *)scs_calloc(1, sizeof(ScsLinSysWork));

  if (!p)
  {
    /* Nothing to clean up yet; without this check the writes below would
     * dereference a NULL pointer. */
    return SCS_NULL;
  }

  p->n = A->n;
  p->m = A->m;
  n_plus_m = p->n + p->m;

  p->diag_r_idxs = (scs_int *)scs_calloc(n_plus_m, sizeof(scs_int));
  p->diag_p = (scs_float *)scs_calloc(p->n, sizeof(scs_float));
  /* A zero-sized request may legitimately yield NULL, so only treat NULL as
   * an error when a non-empty array was requested. */
  if ((n_plus_m > 0 && !p->diag_r_idxs) || (p->n > 0 && !p->diag_p))
  {
    scs_free_lin_sys_work(p);
    return SCS_NULL;
  }

  // p->kkt is CSC in lower triangular form; this is equivalent to upper CSR
  p->kkt = SCS(form_kkt)(A, P, p->diag_p, diag_r, p->diag_r_idxs, 0);
  if (!(p->kkt))
  {
    scs_printf("Error in forming KKT matrix");
    scs_free_lin_sys_work(p);
    return SCS_NULL;
  }

  status = __initialize_work(p);
  if (status != HIPSPARSE_STATUS_SUCCESS)
  {
    scs_printf("Error in factorisation: %d.\n", (int)status);
    scs_free_lin_sys_work(p);
    return SCS_NULL;
  }

  return p;
}

/* Returns solution to linear system Ax = b with solution stored in b */
/*
 * Solves the linear system for right-hand side b and stores the solution
 * back into b.
 *
 * work : workspace created by scs_init_lin_sys_work().
 * b    : on entry the right-hand side, on successful return the solution.
 * ws   : warm-start vector; not used by this backend and may be NULL.
 * tol  : requested tolerance; not used by this backend.
 *
 * Returns 0 on success, -1 for invalid input or a failed host/device copy,
 * and otherwise the non-zero hipSPARSE status of the failed solve. b is left
 * untouched whenever a non-zero value is returned.
 */
scs_int scs_solve_lin_sys(ScsLinSysWork *work, scs_float *b, const scs_float *ws,
                          scs_float tol)
{
  hipsparseStatus_t status;
  scs_float alpha = 1.0;
  scs_int nnz;

  /* The solve writes every entry of d_x, so a warm start has nothing to
   * contribute; uploading it would only be overwritten. */
  (void)ws;
  (void)tol;

  if (work == NULL || b == NULL)
  {
    return -1; // Error: invalid input
  }

  nnz = work->kkt->p[work->kkt->n];

  // Copy b to device
  if (hipMemcpy(work->d_b, b, work->kkt->m * sizeof(scs_float), hipMemcpyHostToDevice) != hipSuccess)
  {
    scs_printf("Error copying right-hand side to device.\n");
    return -1;
  }

  /* csrsv2_solve takes the right-hand side (f) first and the solution (x)
   * second: d_b is the input, d_x receives the result. */
  status = hipsparseDcsrsv2_solve(work->handle, NON_TRANSP, work->kkt->m, nnz, &alpha, work->descr,
                                  work->d_vals, work->d_row_ptrs, work->d_col_inds,
                                  work->info_LU, work->d_b, work->d_x, NO_LVL_POLICY, work->buffer);
  if (status != HIPSPARSE_STATUS_SUCCESS)
  {
    scs_printf("Error during linear system solution: %d.\n", (int)status);
    /* Do not overwrite the caller's b with an invalid result. */
    return (scs_int)status;
  }

  // Copy the solution back to the host
  if (hipMemcpy(b, work->d_x, work->kkt->m * sizeof(scs_float), hipMemcpyDeviceToHost) != hipSuccess)
  {
    scs_printf("Error copying solution back to host.\n");
    return -1;
  }

  return (scs_int)status;
}

/* Update factorization when R changes */
/*
 * Refreshes the KKT matrix after the diagonal regularisation R has changed.
 *
 * The diagonal entries of the host KKT matrix are rewritten using the
 * positions recorded in p->diag_r_idxs, the value array is uploaded to the
 * device again, and the csrsv2 call made at initialisation time is repeated.
 *
 * The function has no return value, so failures are only reported through
 * scs_printf. The workspace is never freed here: it belongs to the caller,
 * who still holds the pointer and will release it later.
 */
void scs_update_lin_sys_diag_r(ScsLinSysWork *p, const scs_float *diag_r)
{
  scs_int i;
  scs_int nnz;
  hipsparseStatus_t status;
  scs_float alpha = 1.0;

  for (i = 0; i < p->n; ++i)
  {
    /* top left block is R_x + P */
    p->kkt->x[p->diag_r_idxs[i]] = p->diag_p[i] + diag_r[i];
  }
  for (i = p->n; i < p->n + p->m; ++i)
  {
    /* bottom right block is -R_y */
    p->kkt->x[p->diag_r_idxs[i]] = -diag_r[i];
  }

  nnz = p->kkt->p[p->kkt->n];
  if (hipMemcpy(p->d_vals, p->kkt->x, nnz * sizeof(scs_float), hipMemcpyHostToDevice) != hipSuccess)
  {
    scs_printf("Error copying updated KKT values to device.\n");
    return;
  }

  /* NOTE(review): same call as at initialisation (NULL right-hand side and
   * solution); the original author labels it "numeric factorization".
   * Confirm that this is the intended use of csrsv2_solve. */
  status = hipsparseDcsrsv2_solve(p->handle, NON_TRANSP, p->kkt->m, nnz, &alpha, p->descr,
                                  p->d_vals, p->d_row_ptrs, p->d_col_inds,
                                  p->info_LU, NULL, NULL, NO_LVL_POLICY, p->buffer);

  if (status != HIPSPARSE_STATUS_SUCCESS)
  {
    scs_printf("Error in factorization when updating: %d.\n",
               (int)status);
    /* Deliberately no scs_free_lin_sys_work(p) here: freeing would leave the
     * caller with a dangling pointer and cause a double free at teardown. */
  }
}
49 changes: 49 additions & 0 deletions linsys/hip/direct/private.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#ifndef PRIV_H_GUARD
#define PRIV_H_GUARD

#ifdef __cplusplus
extern "C"
{
#endif

#include "csparse.h"
#include "linsys.h"
#include <hip/hip_runtime.h>
#include <hipsparse/hipsparse.h>

/*
 * Workspace of the hipSPARSE linear-system backend. It is allocated
 * zero-initialised by scs_init_lin_sys_work() and released by
 * scs_free_lin_sys_work(), which skips members that are still NULL.
 */
struct SCS_LIN_SYS_WORK
{
  // Host:
  ScsMatrix *kkt; /* KKT matrix built by SCS(form_kkt); stored as lower triangular CSC, read as upper triangular CSR */
  scs_int n;      /* number of QP variables (copied from A->n) */
  scs_int m;      /* number of QP constraints (copied from A->m) */

  hipsparseHandle_t handle;  // hipSPARSE library handle
  hipsparseMatDescr_t descr; // Matrix descriptor (zero-based, symmetric, upper fill, non-unit diagonal)

  // kkt matrix data on the device, uploaded from kkt->x, kkt->p and kkt->i
  scs_float *d_vals;   // Non-zero values of the matrix (on device)
  scs_int *d_row_ptrs; // Row pointers (on device), copy of kkt->p
  scs_int *d_col_inds; // Column indices (on device), copy of kkt->i

  // Vectors for solving system Ax = b
  scs_float *d_b; // RHS vector b (on device)
  scs_float *d_x; // Solution vector x (on device)

  // csrsv2 info object created by hipsparseCreateCsrsv2Info and filled by the analysis step
  csrsv2Info_t info_LU; // Solve info shared by the analysis and solve calls

  // Scratch buffer passed to the csrsv2 analysis and solve calls
  void *buffer;       // Device scratch buffer
  scs_int bufferSize; // Size of the buffer as reported by hipsparseDcsrsv2_bufferSize

  /* These are required for matrix updates */
  scs_int *diag_r_idxs; /* indices into kkt->x where R appears (n + m entries) */
  scs_float *diag_p;    /* Diagonal values of P (n entries) */
};

#ifdef __cplusplus
}
#endif

#endif
7 changes: 7 additions & 0 deletions scs.mk
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ endif
#CC = i686-w64-mingw32-gcc -m32
#CC = x86_64-w64-mingw32-gcc-4.8
# Compilers for the GPU backends.
CUCC = $(CC) #Don't need to use nvcc, since using cuda blas APIs
# HIP_PATH is assigned further down in this file; because '=' is a recursive
# assignment, $(HIP_PATH) is only expanded when HIPCC is actually used.
HIPCC = $(HIP_PATH)/bin/hipcc

# For GPU must add cuda libs to path, e.g.
# export DYLD_LIBRARY_PATH=/usr/local/cuda/lib:$DYLD_LIBRARY_PATH
Expand Down Expand Up @@ -70,6 +71,7 @@ INDIRSRC = $(LINSYS)/cpu/indirect
# Source directories of the GPU, MKL and HIP linear-system backends, all
# located below $(LINSYS).
GPUDIR = $(LINSYS)/gpu/direct
GPUINDIR = $(LINSYS)/gpu/indirect
MKLSRC = $(LINSYS)/mkl/direct
HIPSRC = $(LINSYS)/hip/direct

EXTSRC = $(LINSYS)/external

Expand Down Expand Up @@ -135,6 +137,11 @@ endif
# to work for all combinations of platform / compiler / threading options.
MKLFLAGS = -L$(MKLROOT) -L$(MKLROOT)/lib -Wl,--no-as-needed -lmkl_rt -lmkl_gnu_thread -lmkl_core -lgomp -lpthread -ldl

############ HIP / hipSPARSE: ############
# Install prefix of the HIP toolchain. A command-line assignment such as
# `make HIP_PATH=/some/prefix hip` takes precedence over this value.
HIP_PATH = /opt/rocm
# Platform selector; becomes the preprocessor define -D__HIP_PLATFORM_AMD__.
HIP_PLATFORM = HIP_PLATFORM_AMD
# Compile flags for $(HIPCC): platform define, HIP headers, and two warnings
# switched off.
HIPCFLAGS = -D__$(HIP_PLATFORM)__ -I$(HIP_PATH)/include -Wno-extra-semi -Wno-strict-prototypes
# Link flags: HIP runtime (amdhip64) and hipSPARSE.
HIPLDFLAGS = -L$(HIP_PATH)/lib -lamdhip64 -lhipsparse

############ OPENMP: ############
# set USE_OPENMP = 1 to allow openmp (multi-threaded matrix multiplies):
# set the number of threads to, for example, 4 by entering the command:
Expand Down
Loading