Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify to check if alpha is in host memory. #1356

Merged
merged 2 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion library/include/hipblaslt-ext.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ namespace hipblaslt_ext
int8_t beta[16]; //!< The beta value.
// Epilogue inputs
void* scaleA; //!< The scaleA input pointer.
void* scaleB; //!< The scaleA input pointer.
void* scaleB; //!< The scaleB input pointer.
void* scaleC; //!< The scaleC input pointer.
void* scaleD; //!< The scaleD input pointer.
void* scaleAlphaVec; //!< The scaleAlpha vector input pointer.
Expand Down
18 changes: 11 additions & 7 deletions library/src/amd_detail/rocblaslt/include/rocblaslt-types.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,9 @@ typedef enum rocblaslt_epilogue_
typedef enum rocblaslt_pointer_mode_
{
rocblaslt_pointer_mode_host = 0, /**< scalar pointers are in host memory. */
rocblaslt_pointer_mode_device = 1 /**< scalar pointers are in device memory. */
rocblaslt_pointer_mode_device = 1, /**< scalar pointers are in device memory. */
rocblaslt_pointer_mode_alpha_device_vector_beta_host
= 4 /** alpha pointer targets a device memory vector of length equal to the number of rows of matrix D, and beta is a single value in host memory. */
} rocblaslt_pointer_mode;

/*! \ingroup types_module
Expand Down Expand Up @@ -265,16 +267,18 @@ typedef enum rocblaslt_compute_type_
rocblaslt_compute_f64_pedantic = 8, /**< compute will be exactly 64-bit precision */
rocblaslt_compute_i32 = 9, /**< 32-bit integer precision. */
rocblaslt_compute_i32_pedantic = 10, /**< compute will be exactly 32-bit integer precision */
rocblaslt_compute_f32_fast_f8_fnuz = 100, /**< 32-bit input can use fp8 compute */
rocblaslt_compute_f32_fast_bf8_fnuz = 101, /**< 32-bit input can use bf8 compute */
rocblaslt_compute_f32_fast_f8_fnuz = 100, /**< 32-bit input can use fp8 compute */
rocblaslt_compute_f32_fast_bf8_fnuz = 101, /**< 32-bit input can use bf8 compute */
rocblaslt_compute_f32_fast_f8bf8_fnuz
= 102, /**< 32-bit input can use fp8 for A and bf8 for B compute */
rocblaslt_compute_f32_fast_bf8f8_fnuz
= 103, /**< 32-bit input can use bf8 for A and fp8 for B compute */
rocblaslt_compute_f32_fast_f8_ocp = 104, /**< 32-bit input can use fp8 compute */
rocblaslt_compute_f32_fast_bf8_ocp = 105, /**< 32-bit input can use bf8 compute */
rocblaslt_compute_f32_fast_f8bf8_ocp = 106, /**< 32-bit input can use fp8 for A and bf8 for B compute */
rocblaslt_compute_f32_fast_bf8f8_ocp = 107, /**< 32-bit input can use bf8 for A and fp8 for B compute */
rocblaslt_compute_f32_fast_f8_ocp = 104, /**< 32-bit input can use fp8 compute */
rocblaslt_compute_f32_fast_bf8_ocp = 105, /**< 32-bit input can use bf8 compute */
rocblaslt_compute_f32_fast_f8bf8_ocp
= 106, /**< 32-bit input can use fp8 for A and bf8 for B compute */
rocblaslt_compute_f32_fast_bf8f8_ocp
= 107, /**< 32-bit input can use bf8 for A and fp8 for B compute */
} rocblaslt_compute_type;

/*! \ingroup types_module
Expand Down
19 changes: 10 additions & 9 deletions library/src/amd_detail/rocblaslt/src/include/handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,17 @@ struct _rocblaslt_matmul_desc
hipblasOperation_t op_B = HIPBLAS_OP_N;
// epilogue operation
rocblaslt_epilogue epilogue = ROCBLASLT_EPILOGUE_DEFAULT;
// alpha,beta pointer mode
rocblaslt_pointer_mode pointermode = rocblaslt_pointer_mode_host;
// bias vector pointer
void* bias = nullptr;
void* scaleA = nullptr;
void* scaleB = nullptr;
void* scaleC = nullptr;
void* scaleD = nullptr;
void* scaleE = nullptr;
void* pointermode = nullptr;
void* amaxD = nullptr;
hipDataType bias_type = HIPBLASLT_DATATYPE_INVALID;
void* bias = nullptr;
void* scaleA = nullptr;
void* scaleB = nullptr;
void* scaleC = nullptr;
void* scaleD = nullptr;
void* scaleE = nullptr;
void* amaxD = nullptr;
hipDataType bias_type = HIPBLASLT_DATATYPE_INVALID;
// E
void* e = nullptr;
int64_t lde = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,23 +139,25 @@ inline rocblaslt_status validateMatmulDescrArgs(rocblaslt_handle handle,
/*******************************************************************************
* Validate Matmul Arguments
******************************************************************************/
inline rocblaslt_status validateMatmulArgs(int64_t m,
int64_t n,
int64_t k,
const void* alpha,
const void* a,
const void* b,
const void* beta,
const void* c,
const void* d,
int num_batches_a = 1,
int num_batches_b = 1,
int num_batches_c = 1,
int num_batches_d = 1,
int64_t batch_stride_a = 0,
int64_t batch_stride_b = 0,
int64_t batch_stride_c = 0,
int64_t batch_stride_d = 0)
inline rocblaslt_status validateMatmulArgs(int64_t m,
int64_t n,
int64_t k,
const void* alpha,
const void* a,
const void* b,
const void* beta,
const void* c,
const void* d,
int num_batches_a = 1,
int num_batches_b = 1,
int num_batches_c = 1,
int num_batches_d = 1,
int64_t batch_stride_a = 0,
int64_t batch_stride_b = 0,
int64_t batch_stride_c = 0,
int64_t batch_stride_d = 0,
const rocblaslt_pointer_mode& pointermode
= rocblaslt_pointer_mode_host)
{
// sizes must not be negative
if(batch_stride_a < 0 || batch_stride_b < 0 || batch_stride_c < 0 || batch_stride_d < 0)
Expand Down Expand Up @@ -183,9 +185,10 @@ inline rocblaslt_status validateMatmulArgs(int64_t m,
if(!beta)
return rocblaslt_status_invalid_pointer;

// Update for the valid case: ((alpha_in_host && alpha=0) && (A=NULL || B=NULL))
bool alpha_A_B_violation = (!alpha || ((pointermode || (*((float*)alpha))) && (!a || !b)));
// pointers must be valid
// Update for the valid case: (alpha=0 && (A=NULL || B=NULL))
if(n && ((k && (!alpha || ((*((float*)alpha)) && (!a || !b)))) || !c || !d))
if(n && ((k && alpha_A_B_violation) || !c || !d))
return rocblaslt_status_invalid_pointer;

return rocblaslt_status_continue;
Expand Down Expand Up @@ -339,7 +342,8 @@ inline rocblaslt_status rocblaslt_matmul_valid_args(const rocblaslt_matmul_desc
batch_stride_a,
batch_stride_b,
batch_stride_c,
batch_stride_d);
batch_stride_d,
matmul_descr->pointermode);

if(status == rocblaslt_status_continue)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1004,7 +1004,7 @@ rocblaslt_status rocblaslt_matmul_desc_get_attribute(rocblaslt_matmul_desc
log_error(__func__, "invalid buf size", sizeInBytes);
return rocblaslt_status_invalid_value;
}
memcpy(buf, &matmulDesc->pointermode, sizeof(void*));
memcpy(buf, &matmulDesc->pointermode, sizeof(int32_t));
break;
case ROCBLASLT_MATMUL_DESC_BIAS_DATA_TYPE:
if(sizeWritten)
Expand Down
10 changes: 6 additions & 4 deletions library/src/amd_detail/rocblaslt/src/rocblaslt_mat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,8 @@ rocblaslt_status
batch_stride_a,
batch_stride_b,
batch_stride_c,
batch_stride_d);
batch_stride_d,
matmul_descr[i]->pointermode);
if(validArgs == rocblaslt_status_success)
continue;

Expand Down Expand Up @@ -650,10 +651,11 @@ rocblaslt_status rocblaslt_matmul(rocblaslt_handle handle,
return rocblaslt_status_invalid_handle;
}

// Update for the valid case: ((alpha_in_host && alpha=0) && (A=NULL || B=NULL))
bool alpha_A_B_violation
= (!alpha || ((matmul_descr->pointermode || (*((float*)alpha))) && (!A || !B)));
// Check if pointer is valid
// Update for the valid case: (alpha=0 && (A=NULL || B=NULL))
if(alpha == nullptr || beta == nullptr || C == nullptr || D == nullptr
|| ((*((float*)alpha)) && (A == nullptr || B == nullptr)))
if(alpha == nullptr || beta == nullptr || C == nullptr || D == nullptr || alpha_A_B_violation)
{
log_error(__func__, "invalid data pointer");
return rocblaslt_status_invalid_pointer;
Expand Down