diff --git a/src/base/flamec/include/FLA_lapack_prototypes.h b/src/base/flamec/include/FLA_lapack_prototypes.h index a54d49e38..ec98eecad 100644 --- a/src/base/flamec/include/FLA_lapack_prototypes.h +++ b/src/base/flamec/include/FLA_lapack_prototypes.h @@ -237,6 +237,8 @@ FLA_Error FLA_LQ_unb_ext_hip( rocblas_handle handle, FLA_Obj A, void* A_hip, FLA FLA_Error FLA_QR_form_Q_external_hip( rocblas_handle handle, FLA_Obj A, void* A_hip, FLA_Obj t, void* t_hip ); FLA_Error FLA_QR_unb_external_hip( rocblas_handle handle, FLA_Obj A, void* A_hip, FLA_Obj t, void* t_hip ); FLA_Error FLA_QR_unb_ext_hip( rocblas_handle handle, FLA_Obj A, void* A_hip, FLA_Obj t, void* t_hip ); +FLA_Error FLA_SA_Apply_pivots_hip( rocblas_handle handle, FLA_Obj C, void* C_hip, FLA_Obj E, void* E_hip, FLA_Obj p ); +FLA_Error FLA_SA_FS_blk_hip( rocblas_handle handle, FLA_Obj L, FLA_Obj D, void* D_hip, FLA_Obj p, FLA_Obj C, void* C_hip,FLA_Obj E, void* E_hip, dim_t nb_alg ); FLA_Error FLA_Svd_external_hip( rocblas_handle handle, FLA_Svd_type jobu, FLA_Svd_type jobv, FLA_Obj A, void* A_hip, FLA_Obj s, void* s_hip, FLA_Obj U, void* U_hip, FLA_Obj V, void* V_hip ); FLA_Error FLA_Tevdd_external_hip( rocblas_handle handle, FLA_Evd_type jobz, FLA_Obj d, void* d_hip, FLA_Obj e, void* e_hip, FLA_Obj A, void* A_hip ); FLA_Error FLA_Tevd_external_hip( rocblas_handle handle, FLA_Evd_type jobz, FLA_Obj d, void* d_hip, FLA_Obj e, void* e_hip, FLA_Obj A, void* A_hip ); diff --git a/src/base/flamec/include/FLA_main_prototypes.h b/src/base/flamec/include/FLA_main_prototypes.h index 626c01e42..3473ebafb 100644 --- a/src/base/flamec/include/FLA_main_prototypes.h +++ b/src/base/flamec/include/FLA_main_prototypes.h @@ -301,6 +301,9 @@ dim_t FLA_Obj_base_width( FLA_Obj obj ); dim_t FLA_Obj_num_elem_alloc( FLA_Obj obj ); void* FLA_Obj_base_buffer( FLA_Obj obj ); void* FLA_Obj_buffer_at_view( FLA_Obj obj ); +#ifdef FLA_ENABLE_HIP +void* FLA_Obj_hip_buffer_at_view( FLA_Obj obj, void* hip_buffer ); +#endif FLA_Bool FLA_Obj_buffer_is_null( FLA_Obj obj ); FLA_Bool FLA_Obj_is_int( FLA_Obj A ); FLA_Bool FLA_Obj_is_floating_point( FLA_Obj A ); diff --git a/src/base/flamec/main/FLA_Query.c b/src/base/flamec/main/FLA_Query.c index a2d6693e5..90acb057c 100644 --- a/src/base/flamec/main/FLA_Query.c +++ b/src/base/flamec/main/FLA_Query.c @@ -234,6 +234,29 @@ void* FLA_Obj_buffer_at_view( FLA_Obj obj ) return ( void* ) ( buffer + byte_offset ); } +#ifdef FLA_ENABLE_HIP +void* FLA_Obj_hip_buffer_at_view( FLA_Obj obj, void* hip_buffer ) +{ + char* buffer; + size_t elem_size, offm, offn, rs, cs; + size_t byte_offset; + + if ( FLA_Check_error_level() >= FLA_MIN_ERROR_CHECKING ) + FLA_Obj_buffer_at_view_check( obj ); + + elem_size = ( size_t ) FLA_Obj_elem_size( obj ); + rs = ( size_t ) FLA_Obj_row_stride( obj ); + cs = ( size_t ) FLA_Obj_col_stride( obj ); + offm = ( size_t ) obj.offm; + offn = ( size_t ) obj.offn; + + byte_offset = elem_size * ( offn * cs + offm * rs ); + + buffer = ( char * ) hip_buffer; + + return ( void* ) ( buffer + byte_offset ); +} +#endif FLA_Bool FLA_Obj_buffer_is_null( FLA_Obj obj ) diff --git a/src/base/flamec/supermatrix/hip/main/FLASH_Queue_hip.c b/src/base/flamec/supermatrix/hip/main/FLASH_Queue_hip.c index 3eec45d03..888e22d0f 100644 --- a/src/base/flamec/supermatrix/hip/main/FLASH_Queue_hip.c +++ b/src/base/flamec/supermatrix/hip/main/FLASH_Queue_hip.c @@ -469,6 +469,7 @@ void FLASH_Queue_exec_task_hip( FLASH_Task* t, typedef FLA_Error(*flash_eig_gest_hip_p)(rocblas_handle handle, FLA_Inv inv, FLA_Uplo uplo, FLA_Obj A, void* A_hip, FLA_Obj B, void* B_hip ); typedef FLA_Error(*flash_lu_piv_hip_p)(rocblas_handle handle, FLA_Obj A, void* A_hip, FLA_Obj p ); typedef FLA_Error(*flash_lu_piv_copy_hip_p)(rocblas_handle handle, FLA_Obj A, void* A_hip, FLA_Obj p, FLA_Obj U, void* U_hip ); + typedef FLA_Error(*flash_sa_fs_hip_p)(rocblas_handle handle, FLA_Obj L, FLA_Obj D, void* D_hip, FLA_Obj p, FLA_Obj C, void* C_hip,FLA_Obj E, void* E_hip, dim_t nb_alg ); typedef FLA_Error(*flash_trsm_piv_hip_p)(rocblas_handle handle, FLA_Obj A, void* A_hip, FLA_Obj B, void* B_hip, FLA_Obj p ); // Level-3 BLAS @@ -563,6 +564,23 @@ void FLASH_Queue_exec_task_hip( FLASH_Task* t, t->output_arg[1], output_arg[1] ); } + // FLA_SA_FS + else if ( t->func == (void *) FLA_SA_FS_task ) + { + flash_sa_fs_hip_p func; + func = (flash_sa_fs_hip_p) FLA_SA_FS_blk_hip; + + func( handle, + t->fla_arg[0], + t->input_arg[0], + input_arg[0], + t->fla_arg[1], + t->output_arg[1], + output_arg[1], + t->output_arg[0], + output_arg[0], + t->int_arg[0] ); + } // FLA_Trsm_piv else if ( t->func == (void *) FLA_Trsm_piv_task ) { diff --git a/src/base/flamec/supermatrix/include/FLASH_Queue_macro_defs.h b/src/base/flamec/supermatrix/include/FLASH_Queue_macro_defs.h index 60aa4ac94..8152689da 100644 --- a/src/base/flamec/supermatrix/include/FLASH_Queue_macro_defs.h +++ b/src/base/flamec/supermatrix/include/FLASH_Queue_macro_defs.h @@ -97,7 +97,7 @@ also to create a macro for when it is not below to return an error code. (void *) cntl, \ "SA_FS", \ FALSE, \ - FALSE, \ + TRUE, \ 1, 2, 1, 2, \ nb_alg, \ L, p, D, E, C ) diff --git a/src/base/flamec/wrappers/lapack/hip/FLA_SA_Apply_pivots_hip.c b/src/base/flamec/wrappers/lapack/hip/FLA_SA_Apply_pivots_hip.c new file mode 100644 index 000000000..8d0f9bc6e --- /dev/null +++ b/src/base/flamec/wrappers/lapack/hip/FLA_SA_Apply_pivots_hip.c @@ -0,0 +1,126 @@ +/* + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2023, Advanced Micro Devices, Inc. + + This file is part of libflame and is available under the 3-Clause + BSD license, which can be found in the LICENSE file at the top-level + directory, or at http://opensource.org/licenses/BSD-3-Clause + +*/ + +#include "FLAME.h" + +#ifdef FLA_ENABLE_HIP + +#include "hip/hip_runtime_api.h" +#include "rocblas/rocblas.h" + +FLA_Error FLA_SA_Apply_pivots_hip( rocblas_handle handle, FLA_Obj C, void* C_hip, FLA_Obj E, void* E_hip, FLA_Obj p ) +{ + FLA_Datatype datatype; + int m_C, n_C, cs_C; + int cs_E; + // int rs_C; + // int rs_E; + int m_p; + int i; + int* buff_p; + + if ( FLA_Obj_has_zero_dim( C ) ) return FLA_SUCCESS; + + datatype = FLA_Obj_datatype( C ); + + m_C = FLA_Obj_length( C ); + n_C = FLA_Obj_width( C ); + cs_C = FLA_Obj_col_stride( C ); + // rs_C = FLA_Obj_row_stride( C ); + + cs_E = FLA_Obj_col_stride( E ); + // rs_E = FLA_Obj_row_stride( E ); + + m_p = FLA_Obj_length( p ); + + buff_p = ( int * ) FLA_INT_PTR( p ); + + void* C_mat = NULL; + void* E_mat = NULL; + if ( FLASH_Queue_get_malloc_managed_enabled_hip() ) + { + C_mat = FLA_Obj_buffer_at_view( C ); + E_mat = FLA_Obj_buffer_at_view( E ); + } + else + { + C_mat = C_hip; + E_mat = E_hip; + } + + switch ( datatype ){ + + case FLA_FLOAT: + { + float* buff_C = ( float * ) C_mat; + float* buff_E = ( float * ) E_mat; + + for ( i = 0; i < m_p; ++i ) + { + if ( buff_p[ i ] != 0 ) + rocblas_sswap( handle, n_C, + buff_C + 0*cs_C + i, cs_C, + buff_E + 0*cs_E + buff_p[ i ] - ( m_C - i ), cs_E ); + } + break; + } + + case FLA_DOUBLE: + { + double* buff_C = ( double * ) C_mat; + double* buff_E = ( double * ) E_mat; + + for ( i = 0; i < m_p; ++i ) + { + if ( buff_p[ i ] != 0 ) + rocblas_dswap( handle, n_C, + buff_C + 0*cs_C + i, cs_C, + buff_E + 0*cs_E + buff_p[ i ] - ( m_C - i ), cs_E ); + } + break; + } + + case FLA_COMPLEX: + { + rocblas_float_complex* buff_C = ( rocblas_float_complex * ) C_mat; + rocblas_float_complex* buff_E = ( rocblas_float_complex * ) E_mat; + + for ( i = 0; i < m_p; ++i ) + { + if ( buff_p[ i ] != 0 ) + rocblas_cswap( handle, n_C, + buff_C + 0*cs_C + i, cs_C, + buff_E + 0*cs_E + buff_p[ i ] - ( m_C - i ), cs_E ); + } + break; + } + + case FLA_DOUBLE_COMPLEX: + { + rocblas_double_complex* buff_C = ( rocblas_double_complex * ) C_mat; + rocblas_double_complex* buff_E = ( rocblas_double_complex * ) E_mat; + + for ( i = 0; i < m_p; ++i ) + { + if ( buff_p[ i ] != 0 ) + rocblas_zswap( handle, n_C, + buff_C + 0*cs_C + i, cs_C, + buff_E + 0*cs_E + buff_p[ i ] - ( m_C - i ), cs_E ); + } + break; + } + + } + + return FLA_SUCCESS; +} + +#endif diff --git a/src/base/flamec/wrappers/lapack/hip/FLA_SA_FS_blk_hip.c b/src/base/flamec/wrappers/lapack/hip/FLA_SA_FS_blk_hip.c new file mode 100644 index 000000000..2e6414203 --- /dev/null +++ b/src/base/flamec/wrappers/lapack/hip/FLA_SA_FS_blk_hip.c @@ -0,0 +1,116 @@ +/* + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2023, Advanced Micro Devices, Inc. + + This file is part of libflame and is available under the 3-Clause + BSD license, which can be found in the LICENSE file at the top-level + directory, or at http://opensource.org/licenses/BSD-3-Clause + +*/ + +#include "FLAME.h" + +#ifdef FLA_ENABLE_HIP + +#include "hip/hip_runtime_api.h" +#include "rocblas/rocblas.h" + +FLA_Error FLA_SA_FS_blk_hip( rocblas_handle handle, FLA_Obj L, + FLA_Obj D, void* D_hip, FLA_Obj p, FLA_Obj C, void* C_hip, + FLA_Obj E, void* E_hip, dim_t nb_alg ) +{ + + FLA_Obj LT, L0, + LB, L1, + L2; + + FLA_Obj DL, DR, D0, D1, D2; + + FLA_Obj pT, p0, + pB, p1, + p2; + + FLA_Obj CT, C0, + CB, C1, + C2; + + FLA_Obj L1_sqr, L1_rest; + + dim_t b; + + FLA_Part_2x1( L, <, + &LB, 0, FLA_TOP ); + + FLA_Part_1x2( D, &DL, &DR, 0, FLA_LEFT ); + + FLA_Part_2x1( p, &pT, + &pB, 0, FLA_TOP ); + + FLA_Part_2x1( C, &CT, + &CB, 0, FLA_TOP ); + + while ( FLA_Obj_length( LT ) < FLA_Obj_length( L ) ) + { + b = min( FLA_Obj_length( LB ), nb_alg ); + + FLA_Repart_2x1_to_3x1( LT, &L0, + /* ** */ /* ** */ + &L1, + LB, &L2, b, FLA_BOTTOM ); + + FLA_Repart_1x2_to_1x3( DL, /**/ DR, &D0, /**/ &D1, &D2, + b, FLA_RIGHT ); + + FLA_Repart_2x1_to_3x1( pT, &p0, + /* ** */ /* ** */ + &p1, + pB, &p2, b, FLA_BOTTOM ); + + FLA_Repart_2x1_to_3x1( CT, &C0, + /* ** */ /* ** */ + &C1, + CB, &C2, b, FLA_BOTTOM ); + + /*------------------------------------------------------------*/ + + FLA_Part_1x2( L1, &L1_sqr, &L1_rest, b, FLA_LEFT ); + + + FLA_SA_Apply_pivots_hip( handle, C1, FLA_Obj_hip_buffer_at_view( C1, C_hip ), + E, E_hip, p1 ); + + FLA_Trsm_external_hip( handle, FLA_LEFT, FLA_LOWER_TRIANGULAR, + FLA_NO_TRANSPOSE, FLA_UNIT_DIAG, + FLA_ONE, L1_sqr, FLA_Obj_buffer_at_view ( L1_sqr ), + C1, FLA_Obj_hip_buffer_at_view( C1, C_hip ) ); + + FLA_Gemm_external_hip( handle, FLA_NO_TRANSPOSE, FLA_NO_TRANSPOSE, + FLA_MINUS_ONE, D1, FLA_Obj_hip_buffer_at_view( D1, D_hip ), + C1, FLA_Obj_hip_buffer_at_view( C1, C_hip ), FLA_ONE, E, E_hip ); + + /*------------------------------------------------------------*/ + + FLA_Cont_with_3x1_to_2x1( <, L0, + L1, + /* ** */ /* ** */ + &LB, L2, FLA_TOP ); + + FLA_Cont_with_1x3_to_1x2( &DL, /**/ &DR, D0, D1, /**/ D2, + FLA_LEFT ); + + FLA_Cont_with_3x1_to_2x1( &pT, p0, + p1, + /* ** */ /* ** */ + &pB, p2, FLA_TOP ); + + FLA_Cont_with_3x1_to_2x1( &CT, C0, + C1, + /* ** */ /* ** */ + &CB, C2, FLA_TOP ); + } + + return FLA_SUCCESS; +} + +#endif