You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
I'm working on solving a matrix equation that seems to be ill-conditioned, and I'm encountering convergence issues with all the configuration files provided in AMGX. I've attached a JPEG image of the matrix to give an idea of its structure.
Although I have confirmed that the matrix rank is appropriate, the main diagonal of the matrix contains many zero values. This characteristic might be contributing to the convergence problems.
I'm seeking advice on how to properly configure AMGX for this type of equation. To provide more context, I've also attached a snippet of my code that illustrates how I'm constructing and attempting to solve this system.
Any suggestions or guidance on the appropriate AMGX configuration settings for this scenario would be greatly appreciated.
Thank you for your time and help.
//// below is code
/* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of NVIDIA CORPORATION nor the names of its
* contributors may be used to endorse or promote products derived
* from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <mpi.h>
#include "cuda_runtime.h"
#include <stdint.h>
#define M_PI 3.14159265358979323846
/* CUDA error macro.
 *
 * Evaluates `call` once and stores the result in a local cudaError_t. When the
 * result is not cudaSuccess, the file name, line number and the text returned
 * by cudaGetErrorString are printed to stderr and the process terminates with
 * EXIT_FAILURE. The do { ... } while (0) wrapper makes the macro behave like a
 * single statement. */
#define CUDA_SAFE_CALL(call) do { \
cudaError_t err = call; \
if(cudaSuccess != err) { \
fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", \
__FILE__, __LINE__, cudaGetErrorString( err) ); \
exit(EXIT_FAILURE); \
} } while (0)
//#define AMGX_DYNAMIC_LOADING
//#undef AMGX_DYNAMIC_LOADING
#define MAX_MSG_LEN 4096
/* standard or dynamically load library */
#ifdef AMGX_DYNAMIC_LOADING
#include "amgx_capi.h"
#else
#include "amgx_c.h"
#endif
/* Report a fatal error and terminate the whole MPI job.
 *
 * err: NUL-terminated message; it is written to stdout followed by a newline.
 *
 * stdout is flushed before MPI_Abort is called on MPI_COMM_WORLD with error
 * code 1. The trailing exit(1) only runs if MPI_Abort returns. */
void errAndExit(const char *err)
{
    puts(err);
    fflush(stdout);
    MPI_Abort(MPI_COMM_WORLD, 1);
    exit(1);
}
/* Print callback (could be customized): only MPI rank 0 echoes the message.
 *
 * msg:    NUL-terminated text, written to stdout verbatim (no newline added).
 * length: not used by this implementation.
 */
void print_callback(const char *msg, int length)
{
    int my_rank;

    MPI_Comm_rank(MPI_COMM_WORLD, &my_rank);
    if (my_rank != 0)
    {
        return;
    }
    fputs(msg, stdout);
}
/* Print the command-line help through print_callback, shut MPI down and
 * terminate the process with exit code 0. */
void printUsageAndExit()
{
    /* The help text is assembled at compile time by adjacent-literal
     * concatenation instead of run-time string appends. */
    const char *usage =
        "Usage: mpirun [-n nranks] ./example [-mode [dDDI | dDFI | dFFI]] [-p nx ny] [-c config_file] [-amg \"variable1=value1 ... variable3=value3\"] [-gpu] [-it k]\n"
        " -mode: select the solver mode\n"
        " -p nx ny: select x- and y-dimensions of the 2D (5-points) local discretization of the Poisson operator (the global problem size will be nranks*nx*ny)\n"
        " -c: set the amg solver options from the config file\n"
        " -amg: set the amg solver options from the command line\n";

    print_callback(usage, MAX_MSG_LEN);
    MPI_Finalize();
    exit(0);
}
/* Locate a command-line option.
 *
 * argv, argc: the argument vector to search (every entry is examined).
 * parm:       option text to look for; at most the first 100 characters of
 *             each argument are compared.
 *
 * Returns the index of the option in argv, or -1 when it is absent. When the
 * option occurs more than once, an error message is sent through
 * print_callback and the process exits with code 1.
 */
int findParamIndex(char **argv, int argc, const char *parm)
{
    int matches = 0;
    int last_match = -1;

    for (int i = 0; i < argc; i++)
    {
        if (strncmp(argv[i], parm, 100) == 0)
        {
            last_match = i;
            matches++;
        }
    }

    if (matches > 1)
    {
        char msg[MAX_MSG_LEN];
        /* Bounded formatting: parm is caller supplied, so never let it run
         * past the end of the fixed-size message buffer. */
        snprintf(msg, sizeof msg, "ERROR: parameter %s has been specified more than once, exiting\n", parm);
        print_callback(msg, MAX_MSG_LEN);
        exit(1);
    }

    return last_match;
}
/* Fill the continuity-equation rows of the sparse system.
 *
 * -------------------- continuity equation ---------------------------
 * -----------------------detadt = dhu/dx + dhv/dy------------------------
 * detadt is constant, so it is not included in the matrix. In other words,
 * the pivot (diagonal) element is 0 for the continuity equation.
 * --------------------------------------------------------------------
 *
 * n        number of cells to process (cells are numbered row-major)
 * nx       number of cells per grid row (differs from the nx passed to the
 *          momentum kernel)
 * ny       unused here
 * dx       grid spacing in x
 * ndd      per-cell flag; the value 2 marks an open water-level boundary
 * u, v     unused here
 * h        per-cell depth, indexed like ndd
 * detadt   per-cell right-hand side value
 * values   matrix coefficient array, written as double
 * rhs      right-hand side array, written as double
 * row_ptrs offset of the first stored coefficient of every matrix row
 *
 * Cell (row, col) owns matrix row 2*col + 1 + (2*nx + 1)*row. An open-boundary
 * cell gets a single coefficient 1.0 and rhs 0.0. Every other cell gets two
 * coefficients (left and right u face) and rhs detadt.
 *
 * The depth at a face is the mean of the two adjacent cells. At the first and
 * last column of a grid row the missing neighbour is replaced by the cell
 * itself, so the lookup never leaves the current grid row and never indexes
 * h[-1].
 */
void continuityXdirKernel(
    int n,
    int nx,
    int ny,
    double dx,
    int* ndd,
    double* u,
    double* v,
    double* h,
    double* detadt,
    void* values,
    void* rhs,
    int* row_ptrs)
{
    double *coeffs = (double *)values;
    double *b = (double *)rhs;

    (void)ny;
    (void)u;
    (void)v;

    for (int i = 0; i < n; i++) {
        int row = i / nx;                            // grid row of the cell
        int col = i % nx;                            // grid column of the cell
        int idx = 2 * col + 1 + (2 * nx + 1) * row;  // matrix row of this equation
        int nnz = row_ptrs[idx];                     // first stored coefficient of that row

        // open boundary condition for water level
        if (ndd[i] == 2) {
            coeffs[nnz] = 1.0;
            b[idx] = 0.0;
            continue;
        }

        // neighbouring cells, clamped to the current grid row
        int left = (col > 0) ? i - 1 : i;
        int right = (col < nx - 1) ? i + 1 : i;

        // left u coefficient
        coeffs[nnz] = -0.5 * (h[i] + h[left]) / dx;
        // right u coefficient
        coeffs[nnz + 1] = 0.5 * (h[i] + h[right]) / dx;
        // rhs value
        b[idx] = detadt[i];
    }
}
void momentumXdirKernel(int n,
int nx,
int ny,
double dx,
double gra,
double mann,
double Du,
int* ndd,
double* u,
double* v,
double* eta,
double* h,
void* values,
void* rhs,
int* row_ptrs) {
/* -------------------- momentum equation u ---------------------------
* 0 = g * mann^2 / h^(4/3) * |uv| * u // friction term
+ udu/dx + vdu/dy // advection term
g * deta/dx // gravity term
- d/dx (Du * du/dx) // diffusion term
-------------------------------------------------------------------- */
//int i = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = 0; i < n; i++) {
// note that here n, nx is different from those of continuity equation
int row = i / nx; // row index in the 2D work array
int col = i % nx; // col index in the 1D work array
int idx = 2 * col + (2 * nx - 1) * row; // pivot index in the row of sparse matrix
int nnz = row_ptrs[idx]; // index of the first non-zero element in the row
// inner points
if (col > 0 && col < nx - 1) {
int i0 = col + row * (nx - 1);
double h_u_c = 0.5 * (h[i0] + h[i0 - 1]); // h_u_c = h_u_center
double D_h_l = Du;
double D_h_r = Du;
// left U
((double*)values)[nnz] = -fmax(0., u[i]) / dx // upwind advection term
- D_h_l / dx / dx; // diffusion term
// left eta
((double*)values)[nnz + 1] = -gra / dx; // gravity term
// center U
((double*)values)[nnz + 2] = gra * pow(mann, 2) / pow(h_u_c, 4.0 / 3.0) * fabs(u[i]) // friction term
+ (fmax(0., u[i]) - fmin(0., u[i])) / dx // upwind advection term
+ (D_h_l + D_h_r) / dx / dx; // diffusion term
// right eta
((double*)values)[nnz + 3] = gra / dx; // gravity term
// right U
((double*)values)[nnz + 4] = fmin(0., u[i]) / dx // upwind advection term
- D_h_r / dx / dx; // diffusion term
//rhs values
((double*)rhs)[idx] = 0.;
}
// left boundary
else if (col == 0) {
int i0 = col + row * (nx - 1);
// open boundary condition
if (ndd[i0] == 2) {
// center U
((double*)values)[nnz] = 1.0;
// right U
((double*)values)[nnz + 1] = -2.0;
// right right U
((double*)values)[nnz + 2] = 1.0;
//rhs values
((double*)rhs)[idx] = 0.;
}
// closed boundary condition
else {
((double*)values)[nnz] = 1.0;
((double*)rhs)[idx] = 0.;
}
}
// right boundary
else if (col == nx - 1) {
int i0 = col - 1 + row * (nx - 1);
// open boundary condition
if (ndd[i0] == 2) {
// left left U
((double*)values)[nnz] = 1.0;
// left U
((double*)values)[nnz + 1] = -2.0;
// center U
((double*)values)[nnz + 2] = 1.0;
//rhs values
((double*)rhs)[idx] = 0.;
}
// closed boundary condition
else {
((double*)values)[nnz] = 1.0;
((double*)rhs)[idx] = 0.;
}
}
}
}
int main(int argc, char **argv)
{
//parameter parsing
int pidx = 0;
int pidy = 0;
//MPI (with CUDA GPUs)
int rank = 0;
int lrank = 0;
int nranks = 0;
int n;
int nx, ny;
int gpu_count = 0;
MPI_Comm amgx_mpi_comm = MPI_COMM_WORLD;
//versions
int major, minor;
char *ver, *date, *time;
//input matrix and rhs/solution
int *partition_sizes = NULL;
int *partition_vector = NULL;
int partition_vector_size = 0;
//library handles
AMGX_Mode mode;
AMGX_config_handle cfg;
AMGX_resources_handle rsrc;
AMGX_matrix_handle A;
AMGX_vector_handle b, x;
AMGX_solver_handle solver;
//status handling
AMGX_SOLVE_STATUS status;
/* MPI init (with CUDA GPUs) */
//MPI
MPI_Init(&argc, &argv);
MPI_Comm_size(amgx_mpi_comm, &nranks);
MPI_Comm_rank(amgx_mpi_comm, &rank);
//CUDA GPUs
CUDA_SAFE_CALL(cudaGetDeviceCount(&gpu_count));
lrank = rank % gpu_count;
CUDA_SAFE_CALL(cudaSetDevice(lrank));
printf("Process %d selecting device %d\n", rank, lrank);
/* check arguments */
if (argc == 1)
{
printUsageAndExit();
}
/* load the library (if it was dynamically loaded) */
#ifdef AMGX_DYNAMIC_LOADING
void *lib_handle = NULL;
#ifdef _WIN32
lib_handle = amgx_libopen("amgxsh.dll");
#else
lib_handle = amgx_libopen("libamgxsh.so");
#endif
if (lib_handle == NULL)
{
errAndExit("ERROR: can not load the library");
}
//load all the routines
if (amgx_liblink_all(lib_handle) == 0)
{
amgx_libclose(lib_handle);
errAndExit("ERROR: corrupted library loaded\n");
}
#endif
/* init */
AMGX_SAFE_CALL(AMGX_initialize());
AMGX_SAFE_CALL(AMGX_initialize_plugins());
/* system */
AMGX_SAFE_CALL(AMGX_register_print_callback(&print_callback));
AMGX_SAFE_CALL(AMGX_install_signal_handler());
/* get api and build info */
if ((pidx = findParamIndex(argv, argc, "--version")) != -1)
{
AMGX_get_api_version(&major, &minor);
printf("amgx api version: %d.%d\n", major, minor);
AMGX_get_build_info_strings(&ver, &date, &time);
printf("amgx build version: %s\nBuild date and time: %s %s\n", ver, date, time);
AMGX_SAFE_CALL(AMGX_finalize_plugins());
AMGX_SAFE_CALL(AMGX_finalize());
/* close the library (if it was dynamically loaded) */
#ifdef AMGX_DYNAMIC_LOADING
amgx_libclose(lib_handle);
#endif
MPI_Finalize();
exit(0);
}
/* get mode */
if ((pidx = findParamIndex(argv, argc, "-mode")) != -1)
{
if (strncmp(argv[pidx + 1], "dDDI", 100) == 0)
{
mode = AMGX_mode_dDDI;
}
else if (strncmp(argv[pidx + 1], "dDFI", 100) == 0)
{
mode = AMGX_mode_dDFI;
}
else if (strncmp(argv[pidx + 1], "dFFI", 100) == 0)
{
mode = AMGX_mode_dFFI;
}
else
{
errAndExit("ERROR: invalid mode");
}
}
else
{
printf("Warning: No mode specified, using dDDI by default.\n");
mode = AMGX_mode_dDDI;
}
int sizeof_m_val = ((AMGX_GET_MODE_VAL(AMGX_MatPrecision, mode) == AMGX_matDouble)) ? sizeof(double) : sizeof(float);
int sizeof_v_val = ((AMGX_GET_MODE_VAL(AMGX_VecPrecision, mode) == AMGX_vecDouble)) ? sizeof(double) : sizeof(float);
/* create config */
pidx = findParamIndex(argv, argc, "-amg");
pidy = findParamIndex(argv, argc, "-c");
if ((pidx != -1) && (pidy != -1))
{
printf("%s\n", argv[pidx + 1]);
AMGX_SAFE_CALL(AMGX_config_create_from_file_and_string(&cfg, argv[pidy + 1], argv[pidx + 1]));
}
else if (pidy != -1)
{
AMGX_SAFE_CALL(AMGX_config_create_from_file(&cfg, argv[pidy + 1]));
}
else if (pidx != -1)
{
printf("%s\n", argv[pidx + 1]);
AMGX_SAFE_CALL(AMGX_config_create(&cfg, argv[pidx + 1]));
}
else
{
errAndExit("ERROR: no config was specified");
}
/* example of how to handle errors */
//char msg[MAX_MSG_LEN];
//AMGX_RC err_code = AMGX_resources_create(NULL, cfg, &amgx_mpi_comm, 1, &lrank);
//AMGX_SAFE_CALL(AMGX_get_error_string(err_code, msg, MAX_MSG_LEN));
//printf("ERROR: %s\n",msg);
/* switch on internal error handling (no need to use AMGX_SAFE_CALL after this point) */
AMGX_SAFE_CALL(AMGX_config_add_parameters(&cfg, "exception_handling=1"));
/* create resources, matrix, vector and solver */
AMGX_resources_create(&rsrc, cfg, &amgx_mpi_comm, 1, &lrank);
AMGX_matrix_create(&A, rsrc, mode);
AMGX_vector_create(&x, rsrc, mode);
AMGX_vector_create(&b, rsrc, mode);
AMGX_solver_create(&solver, rsrc, mode, cfg);
//generate 3D Poisson matrix, [and rhs & solution]
//WARNING: use 1 ring for aggregation and 2 rings for classical path
int nrings; //=1; //=2;
AMGX_config_get_default_number_of_rings(cfg, &nrings);
//printf("nrings=%d\n",nrings);
int nglobal = 0;
if ((pidx = findParamIndex(argv, argc, "-p")) != -1)
{
nx = atoi(argv[++pidx]);
ny = atoi(argv[++pidx]);
n = nx * ny; // each rank has the strip of size nx*ny
nglobal = n * nranks; // global domain is just those strips stacked one on each other
}
else
{
printf("Please, use '-p nx ny' parameter for this example\n");
exit(1);
}
nx = 5000;
ny = 1;
n = nx * ny;
nglobal = n * nranks;
/* generate the matrix
In more detail, this routine will create 2D (5 point) discretization of the
Poisson operator. The discretization is performed on a the 2D domain consisting
of nx and ny points in x- and y-dimension respectively. Each rank processes it's
own part of discretization points. Finally, the rhs and solution will be set to
a vector of ones and zeros, respectively. */
int *row_ptrs = (int*)malloc((2 * n + ny + 1) * sizeof(int));
int64_t *col_idxs = (int64_t*)malloc(6 * (2 * n + ny) * sizeof(int64_t));
void *values = malloc(6 * (2 * n + ny) * sizeof(double)); // maximum nnz
int nnz = 0;
int64_t count = 0;
int64_t start_idx = rank * n;
//for (int i = 0; i < n; i ++)
//{
// row_ptrs[i] = nnz;
// if (rank > 0 || i > ny-1)
// {
// col_idxs[nnz] = (i + start_idx - ny);
// if (sizeof_m_val == 4)
// {
// ((float *)values)[nnz] = -1.f;
// }
// else if (sizeof_m_val == 8)
// {
// ((double *)values)[nnz] = -1.;
// }
// nnz++;
// }
// if (i % ny != 0)
// {
// col_idxs[nnz] = (i + start_idx - 1);
// if (sizeof_m_val == 4)
// {
// ((float *)values)[nnz] = -1.f;
// }
// else if (sizeof_m_val == 8)
// {
// ((double *)values)[nnz] = -1.;
// }
// nnz++;
// }
// {
// col_idxs[nnz] = (i + start_idx);
// if (sizeof_m_val == 4)
// {
// ((float *)values)[nnz] = 4.f;
// }
// else if (sizeof_m_val == 8)
// {
// ((double *)values)[nnz] = 4.;
// }
// nnz++;
// }
// if ((i + 1) % ny != 0)
// {
// col_idxs[nnz] = (i + start_idx + 1);
// if (sizeof_m_val == 4)
// {
// ((float *)values)[nnz] = -1.f;
// }
// else if (sizeof_m_val == 8)
// {
// ((double *)values)[nnz] = -1.;
// }
// nnz++;
// }
// if ( (rank != nranks - 1) || (i / ny != (nx - 1)) )
// {
// col_idxs[nnz] = (i + start_idx + ny);
// if (sizeof_m_val == 4)
// {
// ((float *)values)[nnz] = -1.f;
// }
// else if (sizeof_m_val == 8)
// {
// ((double *)values)[nnz] = -1.;
// }
// nnz++;
// }
//}
//row_ptrs[n] = nnz;
///////////////////////////////////////////////////////
///////////////////////////////////////////////////////
double dx = .5;
double gra = 9.81;
double mann = 0.012;
double Du = 0.5;
double Trange = 1.0;
double Tperiod = 43200.0;
int* ndd;
double* u;
double* v;
double* eta;
double* h;
double* detadt;
void* rhs;
// Allocate memory on host
ndd = (int*)malloc(n * sizeof(int));
u = (double*)malloc((n + ny) * sizeof(double));
v = (double*)malloc((n + nx) * sizeof(double));
eta = (double*)malloc(n * sizeof(double));
h = (double*)malloc(n * sizeof(double));
detadt = (double*)malloc(n * sizeof(double));
rhs = malloc((2 * n + ny) * sizeof(double));
// Initialize host memory
for (int i = 0; i < n; i++) {
int row = i / nx; // row index in the 2D work array
int col = i % nx; // col index in the 1D work array
if (col == 0) {
ndd[i] = 2;
}
else {
ndd[i] = 0;
}
eta[i] = 0.0;
h[i] = 2.0;
detadt[i] = M_PI * Trange / Tperiod;
}
for (int i = 0; i < n + ny; i++) {
u[i] = 0.0;
}
for (int i = 0; i < n + nx; i++) {
v[i] = 0.0;
}
// initialize sparse matrix (determine row_ptrs and col_idxs)
nnz = 0;
for (int row = 0; row < ny; row++) {
for (int i = 0; i < 2 * nx + 1; i++) {
// continuity equation
if (i % 2 == 1) {
int col = (i - 1) / 2;
int idx = i + (2 * nx + 1) * row;
int i0 = col + row * nx;
row_ptrs[idx] = nnz;
if (ndd[i0] == 2) {
col_idxs[nnz++] = idx;
}
else {
col_idxs[nnz++] = idx - 1;
col_idxs[nnz++] = idx + 1;
}
}
// momentum equation
if (i % 2 == 0) {
int col = i / 2;
int idx = i + (2 * nx + 1) * row;
int i0 = col + row * nx;
row_ptrs[idx] = nnz;
if (col > 0 && col < nx) {
col_idxs[nnz++] = idx - 2;
col_idxs[nnz++] = idx - 1;
col_idxs[nnz++] = idx;
col_idxs[nnz++] = idx + 1;
col_idxs[nnz++] = idx + 2;
}
else if (col == 0) {
if (ndd[i0] == 2) {
printf("idx = %d, nnz = %d\n", idx, nnz);
col_idxs[nnz++] = idx;
col_idxs[nnz++] = idx + 2;
col_idxs[nnz++] = idx + 4;
}
else {
col_idxs[nnz++] = idx;
}
}
else if (col == nx) {
int i0 = col - 1 + row * nx;
if (ndd[i0] == 2) {
col_idxs[nnz++] = idx - 4;
col_idxs[nnz++] = idx - 2;
col_idxs[nnz++] = idx;
}
else {
col_idxs[nnz++] = idx;
}
}
}
}
}
row_ptrs[2 * n + ny] = nnz;
continuityXdirKernel(n, nx, ny, dx, ndd, u, v, h, detadt, (double*)values, rhs, row_ptrs);
momentumXdirKernel(n + ny, nx + 1, ny, dx, gra, mann, Du, ndd, u, v, eta, h, (double*)values, rhs, row_ptrs);
///////////////////////////////////////////////////////
///////////////////////////////////////////////////////
FILE* fp;
fp = fopen("values.bin", "wb");
fwrite(values, sizeof(double), 6 * (2 * n + ny), fp);
fclose(fp);
fp = fopen("rhs.bin", "wb");
fwrite(rhs, sizeof(double), (2 * n + ny), fp);
fclose(fp);
fp = fopen("row_ptrs.bin", "wb");
fwrite(row_ptrs, sizeof(int), (2 * n + ny + 1), fp);
fclose(fp);
fp = fopen("col_idxs.bin", "wb");
fwrite(col_idxs, sizeof(int64_t), 6 * (2 * n + ny), fp);
fclose(fp);
AMGX_matrix_upload_all_global(A,
2 * n + ny, 2 * n + ny, nnz, 1, 1,
row_ptrs, col_idxs, values, NULL,
nrings, nrings, NULL);
free(values);
free(row_ptrs);
free(col_idxs);
/* generate the rhs and solution */
void *h_x = malloc( (2 * n + ny) * sizeof_v_val);
//for (int i = 0; i < 2 * n + ny; i++) {
// if (i % 2 == 0) {
// ((double*)h_x)[i] = 0.0;
// }
// else {
// ((double*)h_x)[i] = 0.0;
// }
//}
memset(h_x, 50.0, (2 * n + ny) * sizeof_v_val);
/* set the connectivity information (for the vector) */
AMGX_vector_bind(x, A);
AMGX_vector_bind(b, A);
/* upload the vector (and the connectivity information) */
AMGX_vector_upload(x, 2 * n + ny, 1, h_x);
AMGX_vector_upload(b, 2 * n + ny, 1, rhs);
/* solver setup */
//MPI barrier for stability (should be removed in practice to maximize performance)
MPI_Barrier(amgx_mpi_comm);
AMGX_solver_setup(solver, A);
/* solver solve */
//MPI barrier for stability (should be removed in practice to maximize performance)
MPI_Barrier(amgx_mpi_comm);
AMGX_solver_solve(solver, b, x);
/* example of how to change parameters between non-linear iterations */
//AMGX_config_add_parameters(&cfg, "config_version=2, default:tolerance=1e-12");
//AMGX_solver_solve(solver, b, x);
/* example of how to replace coefficients between non-linear iterations */
//AMGX_matrix_replace_coefficients(A, n, nnz, values, diag);
//AMGX_solver_setup(solver, A);
//AMGX_solver_solve(solver, b, x);
AMGX_solver_get_status(solver, &status);
/* example of how to get (the local part of) the solution */
//int sizeof_v_val;
//sizeof_v_val = ((NVAMG_GET_MODE_VAL(NVAMG_VecPrecision, mode) == NVAMG_vecDouble))? sizeof(double): sizeof(float);
//
void* SLN = malloc((2 * n + ny) * sizeof(double));
AMGX_vector_download(x, SLN);
fp = fopen("SLN.bin", "wb");
fwrite(SLN, sizeof(double), (2 * n + ny), fp);
fclose(fp);
//free(result_host);
/* destroy resources, matrix, vector and solver */
AMGX_solver_destroy(solver);
AMGX_vector_destroy(x);
AMGX_vector_destroy(b);
AMGX_matrix_destroy(A);
AMGX_resources_destroy(rsrc);
/* destroy config (need to use AMGX_SAFE_CALL after this point) */
AMGX_SAFE_CALL(AMGX_config_destroy(cfg))
/* shutdown and exit */
AMGX_SAFE_CALL(AMGX_finalize_plugins())
AMGX_SAFE_CALL(AMGX_finalize())
/* close the library (if it was dynamically loaded) */
#ifdef AMGX_DYNAMIC_LOADING
amgx_libclose(lib_handle);
#endif
MPI_Finalize();
CUDA_SAFE_CALL(cudaDeviceReset());
return status;
}
Hello!
Can you provide the matrix itself? I'm collecting data for extreme test cases for my solver, and that would be a great help! Plus, I might be able to give some advice on the configuration of AMGX if I have the matrix.
Hello,
I'm working on solving a matrix equation that seems to be ill-conditioned, and I'm encountering convergence issues with all the configuration files provided in AMGX. I've attached a JPEG image of the matrix to give an idea of its structure.
Although I have confirmed that the matrix rank is appropriate, the main diagonal of the matrix contains many zero values. This characteristic might be contributing to the convergence problems.
I'm seeking advice on how to properly configure AMGX for this type of equation. To provide more context, I've also attached a snippet of my code that illustrates how I'm constructing and attempting to solve this system.
Any suggestions or guidance on the appropriate AMGX configuration settings for this scenario would be greatly appreciated.
Thank you for your time and help.
//// below is code
/// below is the message returned
C:\STUDY\testamgx\CudaRuntime1\x64\Release>CudaRuntime1.exe -p 3 3 -c C:\AMGX\AMGX-2.3.0\core\configs\FGMRES_NOPREC.json
Process 0 selecting device 0
AMGX version 2.2.0.132-opensource
Built on Nov 23 2023, 14:14:00
Compiled with CUDA Runtime 11.8, using CUDA driver 12.2
Warning: No mode specified, using dDDI by default.
Cannot read file as JSON object, trying as AMGX config
Converting config string to current config version
Parsing configuration string: exception_handling=1 ;
idx = 0, nnz = 0
Using Normal MPI (Hostbuffer) communicator...
iter Mem Usage (GB) residual rate
--------------------------------------------------------------
Ini 0 5.141711e-03
0 0 5.141711e-03 1.0000
1 0.0000 5.141369e-03 0.9999
2 0.0000 5.141338e-03 1.0000
3 0.0000 5.141076e-03 0.9999
4 0.0000 5.141015e-03 1.0000
5 0.0000 5.140766e-03 1.0000
6 0.0000 5.140733e-03 1.0000
7 0.0000 5.140463e-03 0.9999
8 0.0000 5.140457e-03 1.0000
9 0.0000 5.140179e-03 0.9999
10 0.0000 5.140179e-03 1.0000
11 0.0000 5.139921e-03 0.9999
12 0.0000 5.139921e-03 1.0000
13 0.0000 5.139696e-03 1.0000
14 0.0000 5.139696e-03 1.0000
15 0.0000 5.139506e-03 1.0000
16 0.0000 5.139504e-03 1.0000
17 0.0000 5.139340e-03 1.0000
18 0.0000 5.139340e-03 1.0000
19 0.0000 5.139190e-03 1.0000
20 0.0000 5.139190e-03 1.0000
21 0.0000 5.139173e-03 1.0000
22 0.0000 5.139173e-03 1.0000
23 0.0000 5.139165e-03 1.0000
24 0.0000 5.139165e-03 1.0000
25 0.0000 5.139152e-03 1.0000
26 0.0000 5.139150e-03 1.0000
27 0.0000 5.139132e-03 1.0000
28 0.0000 5.139132e-03 1.0000
29 0.0000 5.139103e-03 1.0000
30 0.0000 5.139097e-03 1.0000
31 0.0000 5.139063e-03 1.0000
32 0.0000 5.139048e-03 1.0000
33 0.0000 5.139012e-03 1.0000
34 0.0000 5.138989e-03 1.0000
35 0.0000 5.138955e-03 1.0000
36 0.0000 5.138920e-03 1.0000
37 0.0000 5.138907e-03 1.0000
38 0.0000 5.138854e-03 1.0000
39 0.0000 5.138854e-03 1.0000
40 0.0000 5.138854e-03 1.0000
41 0.0000 5.138852e-03 1.0000
42 0.0000 5.138851e-03 1.0000
43 0.0000 5.138846e-03 1.0000
44 0.0000 5.138846e-03 1.0000
45 0.0000 5.138834e-03 1.0000
46 0.0000 5.138834e-03 1.0000
47 0.0000 5.138827e-03 1.0000
48 0.0000 5.138827e-03 1.0000
49 0.0000 5.138819e-03 1.0000
50 0.0000 5.138819e-03 1.0000
51 0.0000 5.138807e-03 1.0000
52 0.0000 5.138807e-03 1.0000
53 0.0000 5.138803e-03 1.0000
54 0.0000 5.138802e-03 1.0000
55 0.0000 5.138796e-03 1.0000
56 0.0000 5.138795e-03 1.0000
57 0.0000 5.138795e-03 1.0000
58 0.0000 5.138795e-03 1.0000
59 0.0000 5.138793e-03 1.0000
60 0.0000 5.138793e-03 1.0000
61 0.0000 5.138793e-03 1.0000
62 0.0000 5.138793e-03 1.0000
63 0.0000 5.138793e-03 1.0000
64 0.0000 5.138793e-03 1.0000
65 0.0000 5.138793e-03 1.0000
66 0.0000 5.138793e-03 1.0000
67 0.0000 5.138793e-03 1.0000
68 0.0000 5.138792e-03 1.0000
69 0.0000 5.138792e-03 1.0000
70 0.0000 5.138792e-03 1.0000
71 0.0000 5.138791e-03 1.0000
72 0.0000 5.138789e-03 1.0000
73 0.0000 5.138789e-03 1.0000
74 0.0000 5.138787e-03 1.0000
75 0.0000 5.138785e-03 1.0000
76 0.0000 5.138783e-03 1.0000
77 0.0000 5.138783e-03 1.0000
78 0.0000 5.138777e-03 1.0000
79 0.0000 5.138777e-03 1.0000
80 0.0000 5.138777e-03 1.0000
81 0.0000 5.138777e-03 1.0000
82 0.0000 5.138777e-03 1.0000
83 0.0000 5.138776e-03 1.0000
84 0.0000 5.138776e-03 1.0000
85 0.0000 5.138775e-03 1.0000
86 0.0000 5.138775e-03 1.0000
87 0.0000 5.138774e-03 1.0000
88 0.0000 5.138774e-03 1.0000
89 0.0000 5.138774e-03 1.0000
90 0.0000 5.138773e-03 1.0000
91 0.0000 5.138772e-03 1.0000
92 0.0000 5.138772e-03 1.0000
93 0.0000 5.138772e-03 1.0000
94 0.0000 5.138772e-03 1.0000
95 0.0000 5.138772e-03 1.0000
96 0.0000 5.138772e-03 1.0000
97 0.0000 5.138771e-03 1.0000
98 0.0000 5.138771e-03 1.0000
99 0.0000 5.138771e-03 1.0000
--------------------------------------------------------------
Total Iterations: 100
Avg Convergence Rate: 1.0000
Final Residual: 5.138771e-03
Total Reduction in Residual: 9.994281e-01
Maximum Memory Usage: 0.000 GB
--------------------------------------------------------------
Total Time: 0.0932096
setup: 8.192e-05 s
solve: 0.0931277 s
solve(per iteration): 0.000931277 s
/// below is the shape of the sparse matrix
///
The text was updated successfully, but these errors were encountered: