-
Notifications
You must be signed in to change notification settings - Fork 83
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support the SA_FS task for HIP. (#90)
Details: - Add a rocBLAS-backed implementation for SA_Apply_pivots. - Add a version of SA_FS_blk that uses HIP wrappers. - Add a method to map a FLA_Obj view to its corresponding raw HIP buffer. - Enable SA_FS task for HIP and handle in the HIP queue.
- Loading branch information
1 parent
74b56a2
commit d48aa5d
Showing
7 changed files
with
289 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
126 changes: 126 additions & 0 deletions
126
src/base/flamec/wrappers/lapack/hip/FLA_SA_Apply_pivots_hip.c
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
/* | ||
Copyright (C) 2014, The University of Texas at Austin | ||
Copyright (C) 2023, Advanced Micro Devices, Inc. | ||
This file is part of libflame and is available under the 3-Clause | ||
BSD license, which can be found in the LICENSE file at the top-level | ||
directory, or at http://opensource.org/licenses/BSD-3-Clause | ||
*/ | ||
|
||
#include "FLAME.h" | ||
|
||
#ifdef FLA_ENABLE_HIP | ||
|
||
#include "hip/hip_runtime_api.h" | ||
#include "rocblas/rocblas.h" | ||
|
||
/*
  FLA_SA_Apply_pivots_hip

  Walks the pivot vector p and, for every index i whose pivot entry is
  nonzero, uses the rocBLAS swap routine matching C's datatype to exchange
    - the n_C elements of C starting at element offset i, stride cs_C, with
    - the n_C elements of E starting at element offset
      p[i] - ( m_C - i ), stride cs_E.
  Only the column strides are consulted; row strides are not used.

  Parameters:
    handle        rocBLAS handle the swaps are issued on.
    C, E          object views providing dimensions and column strides; the
                  datatype is taken from C alone.
    C_hip, E_hip  buffers used for C and E when
                  FLASH_Queue_get_malloc_managed_enabled_hip() is false;
                  otherwise FLA_Obj_buffer_at_view() supplies the buffers.
    p             pivot vector, dereferenced in host code via FLA_INT_PTR.

  Returns FLA_SUCCESS in all cases.  The status values returned by the
  rocBLAS calls are not inspected, and a datatype other than the four
  handled below results in no swaps.
*/
FLA_Error FLA_SA_Apply_pivots_hip( rocblas_handle handle, FLA_Obj C, void* C_hip, FLA_Obj E, void* E_hip, FLA_Obj p )
{
  // Nothing to swap when C is empty.
  if ( FLA_Obj_has_zero_dim( C ) )
    return FLA_SUCCESS;

  const FLA_Datatype datatype = FLA_Obj_datatype( C );

  const int rows_C   = FLA_Obj_length( C );
  const int cols_C   = FLA_Obj_width( C );
  const int ldim_C   = FLA_Obj_col_stride( C );
  const int ldim_E   = FLA_Obj_col_stride( E );
  const int n_pivots = FLA_Obj_length( p );

  const int* pivots = ( int * ) FLA_INT_PTR( p );

  // Select which buffers the swaps operate on: the caller-supplied HIP
  // buffers, or the objects' own buffers when managed allocation is enabled.
  void* C_mat;
  void* E_mat;
  if ( !FLASH_Queue_get_malloc_managed_enabled_hip() )
  {
    C_mat = C_hip;
    E_mat = E_hip;
  }
  else
  {
    C_mat = FLA_Obj_buffer_at_view( C );
    E_mat = FLA_Obj_buffer_at_view( E );
  }

  switch ( datatype )
  {
    case FLA_FLOAT:
    {
      float* C_elems = ( float * ) C_mat;
      float* E_elems = ( float * ) E_mat;

      for ( int k = 0; k < n_pivots; ++k )
      {
        // A zero pivot entry means no exchange for this index.
        if ( pivots[ k ] == 0 )
          continue;

        rocblas_sswap( handle, cols_C,
                       C_elems + k, ldim_C,
                       E_elems + ( pivots[ k ] - ( rows_C - k ) ), ldim_E );
      }
      break;
    }

    case FLA_DOUBLE:
    {
      double* C_elems = ( double * ) C_mat;
      double* E_elems = ( double * ) E_mat;

      for ( int k = 0; k < n_pivots; ++k )
      {
        if ( pivots[ k ] == 0 )
          continue;

        rocblas_dswap( handle, cols_C,
                       C_elems + k, ldim_C,
                       E_elems + ( pivots[ k ] - ( rows_C - k ) ), ldim_E );
      }
      break;
    }

    case FLA_COMPLEX:
    {
      rocblas_float_complex* C_elems = ( rocblas_float_complex * ) C_mat;
      rocblas_float_complex* E_elems = ( rocblas_float_complex * ) E_mat;

      for ( int k = 0; k < n_pivots; ++k )
      {
        if ( pivots[ k ] == 0 )
          continue;

        rocblas_cswap( handle, cols_C,
                       C_elems + k, ldim_C,
                       E_elems + ( pivots[ k ] - ( rows_C - k ) ), ldim_E );
      }
      break;
    }

    case FLA_DOUBLE_COMPLEX:
    {
      rocblas_double_complex* C_elems = ( rocblas_double_complex * ) C_mat;
      rocblas_double_complex* E_elems = ( rocblas_double_complex * ) E_mat;

      for ( int k = 0; k < n_pivots; ++k )
      {
        if ( pivots[ k ] == 0 )
          continue;

        rocblas_zswap( handle, cols_C,
                       C_elems + k, ldim_C,
                       E_elems + ( pivots[ k ] - ( rows_C - k ) ), ldim_E );
      }
      break;
    }

    default:
      // Any other datatype: no swaps are performed.
      break;
  }

  return FLA_SUCCESS;
}
|
||
#endif |
116 changes: 116 additions & 0 deletions
116
src/base/flamec/wrappers/lapack/hip/FLA_SA_FS_blk_hip.c
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
/* | ||
Copyright (C) 2014, The University of Texas at Austin | ||
Copyright (C) 2023, Advanced Micro Devices, Inc. | ||
This file is part of libflame and is available under the 3-Clause | ||
BSD license, which can be found in the LICENSE file at the top-level | ||
directory, or at http://opensource.org/licenses/BSD-3-Clause | ||
*/ | ||
|
||
#include "FLAME.h" | ||
|
||
#ifdef FLA_ENABLE_HIP | ||
|
||
#include "hip/hip_runtime_api.h" | ||
#include "rocblas/rocblas.h" | ||
|
||
/*
  FLA_SA_FS_blk_hip

  Blocked driver that walks L, p and C from top to bottom (and D from left
  to right) in steps of b = min( remaining rows of L, nb_alg ).  For each
  step, with L1 / D1 / p1 / C1 denoting the current partitions, it calls:

    1. FLA_SA_Apply_pivots_hip on ( C1, E ) with pivot block p1;
    2. FLA_Trsm_external_hip with side = left, lower triangular,
       no transpose, unit diagonal, alpha = FLA_ONE, using L1_sqr (the
       left b columns of L1) as the triangular operand and C1 as the
       right-hand side;
    3. FLA_Gemm_external_hip with alpha = FLA_MINUS_ONE, beta = FLA_ONE on
       ( D1, C1, E ) -- presumably E := E - D1 * C1.

  Parameters:
    handle   rocBLAS handle forwarded to every HIP wrapper call.
    L        object partitioned by rows; supplies the triangular blocks.
    D, D_hip object partitioned by columns and its associated HIP buffer.
    p        pivot vector, partitioned by rows.
    C, C_hip object partitioned by rows and its associated HIP buffer.
    E, E_hip object passed whole to the pivot and gemm calls, with its
             associated HIP buffer.
    nb_alg   upper bound on the block size used per iteration.

  Returns FLA_SUCCESS in all cases; return values of the called wrappers
  are not inspected.
*/
FLA_Error FLA_SA_FS_blk_hip( rocblas_handle handle, FLA_Obj L,
                             FLA_Obj D, void* D_hip, FLA_Obj p, FLA_Obj C, void* C_hip,
                             FLA_Obj E, void* E_hip, dim_t nb_alg )
{
  FLA_Obj LT,              L0,
          LB,              L1,
                           L2;

  FLA_Obj DL,    DR,       D0,  D1,  D2;

  FLA_Obj pT,              p0,
          pB,              p1,
                           p2;

  FLA_Obj CT,              C0,
          CB,              C1,
                           C2;

  FLA_Obj L1_sqr, L1_rest;

  dim_t b;

  // Initial partitions: the "done" parts (top / left) start out empty.
  FLA_Part_2x1( L,    &LT,
                      &LB,            0, FLA_TOP );

  FLA_Part_1x2( D,    &DL,  &DR,      0, FLA_LEFT );

  FLA_Part_2x1( p,    &pT,
                      &pB,            0, FLA_TOP );

  FLA_Part_2x1( C,    &CT,
                      &CB,            0, FLA_TOP );

  while ( FLA_Obj_length( LT ) < FLA_Obj_length( L ) )
  {
    b = min( FLA_Obj_length( LB ), nb_alg );

    FLA_Repart_2x1_to_3x1( LT,                &L0,
                        /* ** */            /* ** */
                                              &L1,
                           LB,                &L2,        b, FLA_BOTTOM );

    FLA_Repart_1x2_to_1x3( DL,  /**/ DR,        &D0, /**/ &D1, &D2,
                           b, FLA_RIGHT );

    FLA_Repart_2x1_to_3x1( pT,                &p0,
                        /* ** */            /* ** */
                                              &p1,
                           pB,                &p2,        b, FLA_BOTTOM );

    FLA_Repart_2x1_to_3x1( CT,                &C0,
                        /* ** */            /* ** */
                                              &C1,
                           CB,                &C2,        b, FLA_BOTTOM );

    /*------------------------------------------------------------*/

    // Only the leading b columns of L1 take part in the triangular solve.
    FLA_Part_1x2( L1,    &L1_sqr, &L1_rest,      b, FLA_LEFT );

    // C1 does not change as a view during this iteration, so its mapped
    // HIP buffer is looked up once and reused by all three calls below.
    void* C1_hip = FLA_Obj_hip_buffer_at_view( C1, C_hip );

    FLA_SA_Apply_pivots_hip( handle, C1, C1_hip,
                             E, E_hip, p1 );

    // NOTE(review): L1_sqr is passed with the buffer returned by
    // FLA_Obj_buffer_at_view rather than a HIP-mapped buffer (no L_hip
    // parameter exists) -- confirm FLA_Trsm_external_hip can consume this
    // pointer when managed allocation is not enabled.
    FLA_Trsm_external_hip( handle, FLA_LEFT, FLA_LOWER_TRIANGULAR,
                           FLA_NO_TRANSPOSE, FLA_UNIT_DIAG,
                           FLA_ONE, L1_sqr, FLA_Obj_buffer_at_view( L1_sqr ),
                           C1, C1_hip );

    FLA_Gemm_external_hip( handle, FLA_NO_TRANSPOSE, FLA_NO_TRANSPOSE,
                           FLA_MINUS_ONE, D1, FLA_Obj_hip_buffer_at_view( D1, D_hip ),
                           C1, C1_hip, FLA_ONE, E, E_hip );

    /*------------------------------------------------------------*/

    FLA_Cont_with_3x1_to_2x1( &LT,                L0,
                                                  L1,
                            /* ** */           /* ** */
                              &LB,                L2,     FLA_TOP );

    FLA_Cont_with_1x3_to_1x2( &DL,  /**/ &DR,        D0, D1, /**/ D2,
                              FLA_LEFT );

    FLA_Cont_with_3x1_to_2x1( &pT,                p0,
                                                  p1,
                            /* ** */           /* ** */
                              &pB,                p2,     FLA_TOP );

    FLA_Cont_with_3x1_to_2x1( &CT,                C0,
                                                  C1,
                            /* ** */           /* ** */
                              &CB,                C2,     FLA_TOP );
  }

  return FLA_SUCCESS;
}
|
||
#endif |