Skip to content

Commit

Permalink
Support the SA_FS task for HIP. (#90)
Browse files Browse the repository at this point in the history
Details:
- Add a rocBLAS-backed implementation for SA_Apply_pivots.
- Add a version of SA_FS_blk that uses HIP wrappers.
- Add a method to map a FLA_Obj view to its corresponding raw HIP buffer.
- Enable SA_FS task for HIP and handle in the HIP queue.
  • Loading branch information
iotamudelta authored Feb 16, 2023
1 parent 74b56a2 commit d48aa5d
Show file tree
Hide file tree
Showing 7 changed files with 289 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/base/flamec/include/FLA_lapack_prototypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ FLA_Error FLA_LQ_unb_ext_hip( rocblas_handle handle, FLA_Obj A, void* A_hip, FLA
FLA_Error FLA_QR_form_Q_external_hip( rocblas_handle handle, FLA_Obj A, void* A_hip, FLA_Obj t, void* t_hip );
FLA_Error FLA_QR_unb_external_hip( rocblas_handle handle, FLA_Obj A, void* A_hip, FLA_Obj t, void* t_hip );
FLA_Error FLA_QR_unb_ext_hip( rocblas_handle handle, FLA_Obj A, void* A_hip, FLA_Obj t, void* t_hip );
FLA_Error FLA_SA_Apply_pivots_hip( rocblas_handle handle, FLA_Obj C, void* C_hip, FLA_Obj E, void* E_hip, FLA_Obj p );
FLA_Error FLA_SA_FS_blk_hip( rocblas_handle handle, FLA_Obj L, FLA_Obj D, void* D_hip, FLA_Obj p, FLA_Obj C, void* C_hip,FLA_Obj E, void* E_hip, dim_t nb_alg );
FLA_Error FLA_Svd_external_hip( rocblas_handle handle, FLA_Svd_type jobu, FLA_Svd_type jobv, FLA_Obj A, void* A_hip, FLA_Obj s, void* s_hip, FLA_Obj U, void* U_hip, FLA_Obj V, void* V_hip );
FLA_Error FLA_Tevdd_external_hip( rocblas_handle handle, FLA_Evd_type jobz, FLA_Obj d, void* d_hip, FLA_Obj e, void* e_hip, FLA_Obj A, void* A_hip );
FLA_Error FLA_Tevd_external_hip( rocblas_handle handle, FLA_Evd_type jobz, FLA_Obj d, void* d_hip, FLA_Obj e, void* e_hip, FLA_Obj A, void* A_hip );
Expand Down
3 changes: 3 additions & 0 deletions src/base/flamec/include/FLA_main_prototypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,9 @@ dim_t FLA_Obj_base_width( FLA_Obj obj );
dim_t FLA_Obj_num_elem_alloc( FLA_Obj obj );
void* FLA_Obj_base_buffer( FLA_Obj obj );
void* FLA_Obj_buffer_at_view( FLA_Obj obj );
#ifdef FLA_ENABLE_HIP
void* FLA_Obj_hip_buffer_at_view( FLA_Obj obj, void* hip_buffer );
#endif
FLA_Bool FLA_Obj_buffer_is_null( FLA_Obj obj );
FLA_Bool FLA_Obj_is_int( FLA_Obj A );
FLA_Bool FLA_Obj_is_floating_point( FLA_Obj A );
Expand Down
23 changes: 23 additions & 0 deletions src/base/flamec/main/FLA_Query.c
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,29 @@ void* FLA_Obj_buffer_at_view( FLA_Obj obj )
return ( void* ) ( buffer + byte_offset );
}

#ifdef FLA_ENABLE_HIP
void* FLA_Obj_hip_buffer_at_view( FLA_Obj obj, void* hip_buffer )
{
char* buffer;
size_t elem_size, offm, offn, rs, cs;
size_t byte_offset;

if ( FLA_Check_error_level() >= FLA_MIN_ERROR_CHECKING )
FLA_Obj_buffer_at_view_check( obj );

elem_size = ( size_t ) FLA_Obj_elem_size( obj );
rs = ( size_t ) FLA_Obj_row_stride( obj );
cs = ( size_t ) FLA_Obj_col_stride( obj );
offm = ( size_t ) obj.offm;
offn = ( size_t ) obj.offn;

byte_offset = elem_size * ( offn * cs + offm * rs );

buffer = ( char * ) hip_buffer;

return ( void* ) ( buffer + byte_offset );
}
#endif


FLA_Bool FLA_Obj_buffer_is_null( FLA_Obj obj )
Expand Down
18 changes: 18 additions & 0 deletions src/base/flamec/supermatrix/hip/main/FLASH_Queue_hip.c
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,7 @@ void FLASH_Queue_exec_task_hip( FLASH_Task* t,
typedef FLA_Error(*flash_eig_gest_hip_p)(rocblas_handle handle, FLA_Inv inv, FLA_Uplo uplo, FLA_Obj A, void* A_hip, FLA_Obj B, void* B_hip );
typedef FLA_Error(*flash_lu_piv_hip_p)(rocblas_handle handle, FLA_Obj A, void* A_hip, FLA_Obj p );
typedef FLA_Error(*flash_lu_piv_copy_hip_p)(rocblas_handle handle, FLA_Obj A, void* A_hip, FLA_Obj p, FLA_Obj U, void* U_hip );
typedef FLA_Error(*flash_sa_fs_hip_p)(rocblas_handle handle, FLA_Obj L, FLA_Obj D, void* D_hip, FLA_Obj p, FLA_Obj C, void* C_hip,FLA_Obj E, void* E_hip, dim_t nb_alg );
typedef FLA_Error(*flash_trsm_piv_hip_p)(rocblas_handle handle, FLA_Obj A, void* A_hip, FLA_Obj B, void* B_hip, FLA_Obj p );

// Level-3 BLAS
Expand Down Expand Up @@ -563,6 +564,23 @@ void FLASH_Queue_exec_task_hip( FLASH_Task* t,
t->output_arg[1],
output_arg[1] );
}
// FLA_SA_FS
else if ( t->func == (void *) FLA_SA_FS_task )
{
flash_sa_fs_hip_p func;
func = (flash_sa_fs_hip_p) FLA_SA_FS_blk_hip;

func( handle,
t->fla_arg[0],
t->input_arg[0],
input_arg[0],
t->fla_arg[1],
t->output_arg[1],
output_arg[1],
t->output_arg[0],
output_arg[0],
t->int_arg[0] );
}
// FLA_Trsm_piv
else if ( t->func == (void *) FLA_Trsm_piv_task )
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ also to create a macro for when it is not below to return an error code.
(void *) cntl, \
"SA_FS", \
FALSE, \
FALSE, \
TRUE, \
1, 2, 1, 2, \
nb_alg, \
L, p, D, E, C )
Expand Down
126 changes: 126 additions & 0 deletions src/base/flamec/wrappers/lapack/hip/FLA_SA_Apply_pivots_hip.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/*
Copyright (C) 2014, The University of Texas at Austin
Copyright (C) 2023, Advanced Micro Devices, Inc.
This file is part of libflame and is available under the 3-Clause
BSD license, which can be found in the LICENSE file at the top-level
directory, or at http://opensource.org/licenses/BSD-3-Clause
*/

#include "FLAME.h"

#ifdef FLA_ENABLE_HIP

#include "hip/hip_runtime_api.h"
#include "rocblas/rocblas.h"

FLA_Error FLA_SA_Apply_pivots_hip( rocblas_handle handle, FLA_Obj C, void* C_hip, FLA_Obj E, void* E_hip, FLA_Obj p )
{
FLA_Datatype datatype;
int m_C, n_C, cs_C;
int cs_E;
// int rs_C;
// int rs_E;
int m_p;
int i;
int* buff_p;

if ( FLA_Obj_has_zero_dim( C ) ) return FLA_SUCCESS;

datatype = FLA_Obj_datatype( C );

m_C = FLA_Obj_length( C );
n_C = FLA_Obj_width( C );
cs_C = FLA_Obj_col_stride( C );
// rs_C = FLA_Obj_row_stride( C );

cs_E = FLA_Obj_col_stride( E );
// rs_E = FLA_Obj_row_stride( E );

m_p = FLA_Obj_length( p );

buff_p = ( int * ) FLA_INT_PTR( p );

void* C_mat = NULL;
void* E_mat = NULL;
if ( FLASH_Queue_get_malloc_managed_enabled_hip() )
{
C_mat = FLA_Obj_buffer_at_view( C );
E_mat = FLA_Obj_buffer_at_view( E );
}
else
{
C_mat = C_hip;
E_mat = E_hip;
}

switch ( datatype ){

case FLA_FLOAT:
{
float* buff_C = ( float * ) C_mat;
float* buff_E = ( float * ) E_mat;

for ( i = 0; i < m_p; ++i )
{
if ( buff_p[ i ] != 0 )
rocblas_sswap( handle, n_C,
buff_C + 0*cs_C + i, cs_C,
buff_E + 0*cs_E + buff_p[ i ] - ( m_C - i ), cs_E );
}
break;
}

case FLA_DOUBLE:
{
double* buff_C = ( double * ) C_mat;
double* buff_E = ( double * ) E_mat;

for ( i = 0; i < m_p; ++i )
{
if ( buff_p[ i ] != 0 )
rocblas_dswap( handle, n_C,
buff_C + 0*cs_C + i, cs_C,
buff_E + 0*cs_E + buff_p[ i ] - ( m_C - i ), cs_E );
}
break;
}

case FLA_COMPLEX:
{
rocblas_float_complex* buff_C = ( rocblas_float_complex * ) C_mat;
rocblas_float_complex* buff_E = ( rocblas_float_complex * ) E_mat;

for ( i = 0; i < m_p; ++i )
{
if ( buff_p[ i ] != 0 )
rocblas_cswap( handle, n_C,
buff_C + 0*cs_C + i, cs_C,
buff_E + 0*cs_E + buff_p[ i ] - ( m_C - i ), cs_E );
}
break;
}

case FLA_DOUBLE_COMPLEX:
{
rocblas_double_complex* buff_C = ( rocblas_double_complex * ) C_mat;
rocblas_double_complex* buff_E = ( rocblas_double_complex * ) E_mat;

for ( i = 0; i < m_p; ++i )
{
if ( buff_p[ i ] != 0 )
rocblas_zswap( handle, n_C,
buff_C + 0*cs_C + i, cs_C,
buff_E + 0*cs_E + buff_p[ i ] - ( m_C - i ), cs_E );
}
break;
}

}

return FLA_SUCCESS;
}

#endif
116 changes: 116 additions & 0 deletions src/base/flamec/wrappers/lapack/hip/FLA_SA_FS_blk_hip.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/*
Copyright (C) 2014, The University of Texas at Austin
Copyright (C) 2023, Advanced Micro Devices, Inc.
This file is part of libflame and is available under the 3-Clause
BSD license, which can be found in the LICENSE file at the top-level
directory, or at http://opensource.org/licenses/BSD-3-Clause
*/

#include "FLAME.h"

#ifdef FLA_ENABLE_HIP

#include "hip/hip_runtime_api.h"
#include "rocblas/rocblas.h"

FLA_Error FLA_SA_FS_blk_hip( rocblas_handle handle, FLA_Obj L,
FLA_Obj D, void* D_hip, FLA_Obj p, FLA_Obj C, void* C_hip,
FLA_Obj E, void* E_hip, dim_t nb_alg )
{

FLA_Obj LT, L0,
LB, L1,
L2;

FLA_Obj DL, DR, D0, D1, D2;

FLA_Obj pT, p0,
pB, p1,
p2;

FLA_Obj CT, C0,
CB, C1,
C2;

FLA_Obj L1_sqr, L1_rest;

dim_t b;

FLA_Part_2x1( L, &LT,
&LB, 0, FLA_TOP );

FLA_Part_1x2( D, &DL, &DR, 0, FLA_LEFT );

FLA_Part_2x1( p, &pT,
&pB, 0, FLA_TOP );

FLA_Part_2x1( C, &CT,
&CB, 0, FLA_TOP );

while ( FLA_Obj_length( LT ) < FLA_Obj_length( L ) )
{
b = min( FLA_Obj_length( LB ), nb_alg );

FLA_Repart_2x1_to_3x1( LT, &L0,
/* ** */ /* ** */
&L1,
LB, &L2, b, FLA_BOTTOM );

FLA_Repart_1x2_to_1x3( DL, /**/ DR, &D0, /**/ &D1, &D2,
b, FLA_RIGHT );

FLA_Repart_2x1_to_3x1( pT, &p0,
/* ** */ /* ** */
&p1,
pB, &p2, b, FLA_BOTTOM );

FLA_Repart_2x1_to_3x1( CT, &C0,
/* ** */ /* ** */
&C1,
CB, &C2, b, FLA_BOTTOM );

/*------------------------------------------------------------*/

FLA_Part_1x2( L1, &L1_sqr, &L1_rest, b, FLA_LEFT );


FLA_SA_Apply_pivots_hip( handle, C1, FLA_Obj_hip_buffer_at_view( C1, C_hip ),
E, E_hip, p1 );

FLA_Trsm_external_hip( handle, FLA_LEFT, FLA_LOWER_TRIANGULAR,
FLA_NO_TRANSPOSE, FLA_UNIT_DIAG,
FLA_ONE, L1_sqr, FLA_Obj_buffer_at_view ( L1_sqr ),
C1, FLA_Obj_hip_buffer_at_view( C1, C_hip ) );

FLA_Gemm_external_hip( handle, FLA_NO_TRANSPOSE, FLA_NO_TRANSPOSE,
FLA_MINUS_ONE, D1, FLA_Obj_hip_buffer_at_view( D1, D_hip ),
C1, FLA_Obj_hip_buffer_at_view( C1, C_hip ), FLA_ONE, E, E_hip );

/*------------------------------------------------------------*/

FLA_Cont_with_3x1_to_2x1( &LT, L0,
L1,
/* ** */ /* ** */
&LB, L2, FLA_TOP );

FLA_Cont_with_1x3_to_1x2( &DL, /**/ &DR, D0, D1, /**/ D2,
FLA_LEFT );

FLA_Cont_with_3x1_to_2x1( &pT, p0,
p1,
/* ** */ /* ** */
&pB, p2, FLA_TOP );

FLA_Cont_with_3x1_to_2x1( &CT, C0,
C1,
/* ** */ /* ** */
&CB, C2, FLA_TOP );
}

return FLA_SUCCESS;
}

#endif

0 comments on commit d48aa5d

Please sign in to comment.