diff --git a/src/C-interface/dense/bml_parallel_dense.c b/src/C-interface/dense/bml_parallel_dense.c index eecf6aee4..e701554be 100644 --- a/src/C-interface/dense/bml_parallel_dense.c +++ b/src/C-interface/dense/bml_parallel_dense.c @@ -204,13 +204,24 @@ bml_mpi_bcast_matrix_dense( const int root, MPI_Comm comm) { - // create MPI data type to avoid multiple messages - MPI_Datatype mpi_data_type; - bml_mpi_type_create_struct_dense(A, &mpi_data_type); - - MPI_Bcast(A, 1, mpi_data_type, root, comm); - - MPI_Type_free(&mpi_data_type); + switch (A->matrix_precision) + { + case single_real: + return bml_mpi_bcast_matrix_dense_single_real(A, root, comm); + break; + case double_real: + return bml_mpi_bcast_matrix_dense_double_real(A, root, comm); + break; + case single_complex: + return bml_mpi_bcast_matrix_dense_single_complex(A, root, comm); + break; + case double_complex: + return bml_mpi_bcast_matrix_dense_double_complex(A, root, comm); + break; + default: + LOG_ERROR("unknown precision\n"); + break; + } } #endif diff --git a/src/C-interface/dense/bml_parallel_dense.h b/src/C-interface/dense/bml_parallel_dense.h index 7bf44c3f5..4e9ea34a7 100644 --- a/src/C-interface/dense/bml_parallel_dense.h +++ b/src/C-interface/dense/bml_parallel_dense.h @@ -146,6 +146,23 @@ void bml_mpi_bcast_matrix_dense( bml_matrix_dense_t * A, const int root, MPI_Comm comm); + +void bml_mpi_bcast_matrix_dense_single_real( + bml_matrix_dense_t * A, + const int root, + MPI_Comm comm); +void bml_mpi_bcast_matrix_dense_double_real( + bml_matrix_dense_t * A, + const int root, + MPI_Comm comm); +void bml_mpi_bcast_matrix_dense_single_complex( + bml_matrix_dense_t * A, + const int root, + MPI_Comm comm); +void bml_mpi_bcast_matrix_dense_double_complex( + bml_matrix_dense_t * A, + const int root, + MPI_Comm comm); #endif #endif diff --git a/src/C-interface/dense/bml_parallel_dense_typed.c b/src/C-interface/dense/bml_parallel_dense_typed.c index ca376a8db..fa72f054a 100644 --- a/src/C-interface/dense/bml_parallel_dense_typed.c +++ b/src/C-interface/dense/bml_parallel_dense_typed.c @@ -22,11 +22,6 @@ #include #endif -#ifdef BML_USE_MAGMA - // buffer on CPU to be used for communications -static MAGMA_T *A_matrix_buffer; -#endif - /** Gather a bml matrix across MPI ranks. * * \ingroup parallel_group @@ -164,4 +159,27 @@ bml_matrix_dense_t return A_bml; } +void TYPED_FUNC( + bml_mpi_bcast_matrix_dense) ( + bml_matrix_dense_t * A, + const int root, + MPI_Comm comm) +{ +#ifdef BML_USE_MAGMA + MAGMA_T *A_matrix = bml_allocate_memory(sizeof(MAGMA_T) * A->N * A->N); + MAGMA(getmatrix) (A->N, A->N, A->matrix, A->ld, A_matrix, A->N, + bml_queue()); +#else + REAL_T *A_matrix = A->matrix; +#endif + + MPI_Bcast(A_matrix, A->N * A->N, MPI_T, root, comm); + +#ifdef BML_USE_MAGMA + MAGMA(setmatrix) (A->N, A->N, A_matrix, A->N, A->matrix, A->ld, + bml_queue()); + bml_free_memory(A_matrix); +#endif +} + #endif