diff --git a/include/flexflow/accessor.h b/include/flexflow/accessor.h index 65ab33b513..121140c926 100644 --- a/include/flexflow/accessor.h +++ b/include/flexflow/accessor.h @@ -5,16 +5,25 @@ #if defined(FF_USE_CUDA) #include +#include #elif defined(FF_USE_HIP_CUDA) #include +#include #elif defined(FF_USE_HIP_ROCM) #include +#include #endif // using namespace Legion; namespace FlexFlow { +#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA) +typedef __nv_bfloat16 __ff_bfloat16; +#elif defined(FF_USE_HIP_ROCM) +typedef hip_bfloat16 __ff_bfloat16; +#endif + template using AccessorRO = Legion::FieldAccessor>; @@ -61,6 +70,7 @@ class GenericTensorAccessorW { float *get_float_ptr() const; double *get_double_ptr() const; half *get_half_ptr() const; + __ff_bfloat16 *get_bfloat16_ptr() const; char *get_byte_ptr() const; DataType data_type; Legion::Domain domain; @@ -80,6 +90,7 @@ class GenericTensorAccessorR { float const *get_float_ptr() const; double const *get_double_ptr() const; half const *get_half_ptr() const; + __ff_bfloat16 const *get_bfloat16_ptr() const; char const *get_byte_ptr() const; DataType data_type; Legion::Domain domain; diff --git a/include/flexflow/ffconst.h b/include/flexflow/ffconst.h index 512645e624..49d5f4d1e4 100644 --- a/include/flexflow/ffconst.h +++ b/include/flexflow/ffconst.h @@ -35,6 +35,7 @@ enum DataType { DT_DOUBLE = 45, DT_INT4 = 46, DT_INT8 = 47, + DT_BF16 = 48, DT_NONE = 49, }; @@ -269,4 +270,10 @@ enum { PARALLEL_TENSOR_GUID_FIRST_VALID = 4000000, NODE_GUID_FIRST_VALID = 5000000, }; + +enum InferencePrecision { + INFERENCE_FLOAT = 800, + INFERENCE_HALF = 801, + INFERENCE_BFLOAT16 = 802, +}; #endif // _FLEXFLOW_CONST_H_ diff --git a/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh b/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh index d1e0e050b2..dfc99bc612 100644 --- a/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh +++ b/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh @@ -23,6 +23,25 @@ struct half8 { half c; half d; }; + +struct __nv_bfloat164 { + __nv_bfloat16 x; + __nv_bfloat16 y; + __nv_bfloat16 z; + __nv_bfloat16 w; +}; + +struct __nv_bfloat168 { + __nv_bfloat16 x; + __nv_bfloat16 y; + __nv_bfloat16 z; + __nv_bfloat16 w; + __nv_bfloat16 a; + __nv_bfloat16 b; + __nv_bfloat16 c; + __nv_bfloat16 d; +}; + struct float8 { float x; float y; @@ -61,6 +80,18 @@ template <> struct VEC_K { using Type = half4; }; +template <> +struct VEC_K<__nv_bfloat16, 1> { + using Type = __nv_bfloat16; +}; +template <> +struct VEC_K<__nv_bfloat16, 2> { + using Type = __nv_bfloat162; +}; +template <> +struct VEC_K<__nv_bfloat16, 4> { + using Type = __nv_bfloat164; +}; // data type for QK production template @@ -95,6 +126,23 @@ struct Vec_fp32_ { using Type = float8; }; +template <> +struct Vec_fp32_<__nv_bfloat16> { + using Type = float; +}; +template <> +struct Vec_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template <> +struct Vec_fp32_<__nv_bfloat164> { + using Type = float4; +}; +template <> +struct Vec_fp32_<__nv_bfloat168> { + using Type = float8; +}; + template struct VEC_V {}; template <> @@ -105,6 +153,10 @@ template <> struct VEC_V { using Type = half8; }; +template <> +struct VEC_V<__nv_bfloat16> { + using Type = __nv_bfloat168; +}; ////////////////data structures half/////////////// @@ -331,6 +383,42 @@ inline __device__ float8 cast_to_float(half8 u) { return tmp; } +inline __device__ float cast_to_float(__nv_bfloat16 u) { + return __bfloat162float(u); +} + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 cast_to_float(__nv_bfloat162 u) { + float2 tmp; + tmp.x = __bfloat162float(u.x); + tmp.y = __bfloat162float(u.y); + return tmp; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 cast_to_float(__nv_bfloat164 u) { + float4 tmp; + tmp.x = __bfloat162float(u.x); + tmp.y = __bfloat162float(u.y); + tmp.z = __bfloat162float(u.z); + tmp.w = __bfloat162float(u.w); + return tmp; +} +inline __device__ float8 cast_to_float(__nv_bfloat168 u) { + float8 tmp; + tmp.x = __bfloat162float(u.x); + tmp.y = __bfloat162float(u.y); + tmp.z = __bfloat162float(u.z); + tmp.w = __bfloat162float(u.w); + tmp.a = __bfloat162float(u.a); + tmp.b = __bfloat162float(u.b); + tmp.c = __bfloat162float(u.c); + tmp.d = __bfloat162float(u.d); + return tmp; +} + inline __device__ void convert_from_float(float4 &dst, float4 src) { dst = src; } @@ -369,6 +457,32 @@ inline __device__ void convert_from_float(half &dst, float src) { dst = __float2half(src); } +// +inline __device__ void convert_from_float(__nv_bfloat164 &dst, float4 src) { + dst.x = __float2bfloat16(src.x); + dst.y = __float2bfloat16(src.y); + dst.z = __float2bfloat16(src.z); + dst.w = __float2bfloat16(src.w); +} + +inline __device__ void convert_from_float(__nv_bfloat168 &dst, float8 src) { + dst.x = __float2bfloat16(src.x); + dst.y = __float2bfloat16(src.y); + dst.z = __float2bfloat16(src.z); + dst.w = __float2bfloat16(src.w); + dst.a = __float2bfloat16(src.a); + dst.b = __float2bfloat16(src.b); + dst.c = __float2bfloat16(src.c); + dst.d = __float2bfloat16(src.d); +} +inline __device__ void convert_from_float(__nv_bfloat162 &dst, float2 src) { + dst.x = __float2bfloat16(src.x); + dst.y = __float2bfloat16(src.y); +} +inline __device__ void convert_from_float(__nv_bfloat16 &dst, float src) { + dst = __float2bfloat16(src); +} + //////////////////////////////////////utils/////////////////////////////////////////////// template diff --git a/include/flexflow/utils/cuda_helper.h b/include/flexflow/utils/cuda_helper.h index f8bf67b3e1..f3038860d7 100644 --- a/include/flexflow/utils/cuda_helper.h +++ b/include/flexflow/utils/cuda_helper.h @@ -161,6 +161,20 @@ T *download_tensor(T const *ptr, size_t num_elements); template bool download_tensor(T const *ptr, T *dst, size_t num_elements); +// data type for cublasgemm +template +struct cublasAlphaBetaType { + using type = float; // default +}; +template <> +struct cublasAlphaBetaType { + using type = half; +}; +template <> +struct cublasAlphaBetaType<__nv_bfloat16> { + using type = float; +}; + cudnnStatus_t cudnnSetTensorDescriptorFromDomain(cudnnTensorDescriptor_t tensor, Legion::Domain domain, DataType data_type = DT_FLOAT); diff --git a/include/flexflow/utils/file_loader.h b/include/flexflow/utils/file_loader.h index 646eb18da2..39b46955b5 100644 --- a/include/flexflow/utils/file_loader.h +++ b/include/flexflow/utils/file_loader.h @@ -38,6 +38,8 @@ class FileDataLoader { template void load_single_weight_tensor(FFModel *ff, Layer *l, int weight_idx); + void load_single_weight_tensor_b16(FFModel *ff, Layer *l, int weight_idx); + void load_quantization_weight(FFModel *ff, Layer *l, int weight_idx); void load_weights(FFModel *ff); @@ -46,6 +48,8 @@ class FileDataLoader { ParallelTensor position_pt, int max_seq_length, int offset); + // template + // void load_from_file(DT *ptr, size_t size, std::string filepath); 
private: int num_heads, num_kv_heads, tensor_parallelism_degree; diff --git a/inference/incr_decoding/incr_decoding.cc b/inference/incr_decoding/incr_decoding.cc index f88af3bc43..75f5a7b2bc 100644 --- a/inference/incr_decoding/incr_decoding.cc +++ b/inference/incr_decoding/incr_decoding.cc @@ -40,7 +40,9 @@ void parse_input_args(char **argv, int argc, FilePaths &paths, std::string &llm_model_name, + DataType &data_type, bool &use_full_precision, + bool &use_bfloat16_precision, bool &verbose, bool &do_sample, float &temperature, @@ -74,6 +76,12 @@ void parse_input_args(char **argv, } if (!strcmp(argv[i], "--use-full-precision")) { use_full_precision = true; + data_type = DT_FLOAT; + continue; + } + if (!strcmp(argv[i], "--use-bfloat16-precision")) { + use_bfloat16_precision = true; + data_type = DT_BF16; continue; } // verbose logging to stdout @@ -126,7 +134,9 @@ void FlexFlow::top_level_task(Task const *task, } FilePaths file_paths; std::string llm_model_name; + DataType data_type = DT_HALF; bool use_full_precision = false; + bool use_bfloat16_precision = false; bool verbose = false; bool do_sample = false; float temperature = 0.0f; @@ -142,7 +152,9 @@ void FlexFlow::top_level_task(Task const *task, argc, file_paths, llm_model_name, + data_type, use_full_precision, + use_bfloat16_precision, verbose, do_sample, temperature, @@ -159,11 +171,15 @@ void FlexFlow::top_level_task(Task const *task, {file_paths.cache_folder_path, "configs", llm_model_name, "config.json"}); std::string tokenizer_filepath = join_path({file_paths.cache_folder_path, "tokenizers", llm_model_name}); - std::string weights_filepath = - join_path({file_paths.cache_folder_path, - "weights", - llm_model_name, - use_full_precision ? "full-precision" : "half-precision"}); + + // bfloat16 shares same weight file with float32 + std::string weights_filepath = join_path( + {file_paths.cache_folder_path, + "weights", + llm_model_name, + use_full_precision + ? "full-precision" + : (use_bfloat16_precision ? "full-precision" : "half-precision")}); std::ifstream config_file_handle(config_filepath); if (!config_file_handle.good()) { std::cout << "Model config file " << config_filepath << " not found." 
@@ -220,33 +236,27 @@ void FlexFlow::top_level_task(Task const *task, weights_filepath, INC_DECODING_MODE, generationConfig, - use_full_precision); + data_type); } else if (model_type == ModelType::OPT) { - OPT::create_opt_model(model, - config_filepath, - weights_filepath, - INC_DECODING_MODE, - use_full_precision); + OPT::create_opt_model( + model, config_filepath, weights_filepath, INC_DECODING_MODE, data_type); } else if (model_type == ModelType::FALCON) { - FALCON::create_falcon_model(model, - config_filepath, - weights_filepath, - INC_DECODING_MODE, - use_full_precision); + FALCON::create_falcon_model( + model, config_filepath, weights_filepath, INC_DECODING_MODE, data_type); } else if (model_type == ModelType::STARCODER) { STARCODER::create_starcoder_model(model, config_filepath, weights_filepath, INC_DECODING_MODE, generationConfig, - use_full_precision); + data_type); } else if (model_type == ModelType::MPT) { MPT::create_mpt_model(model, config_filepath, weights_filepath, INC_DECODING_MODE, generationConfig, - use_full_precision); + data_type); } else { assert(false && "unknow model type"); } diff --git a/inference/models/falcon.cc b/inference/models/falcon.cc index e00f4e9cfd..5396b32fbd 100644 --- a/inference/models/falcon.cc +++ b/inference/models/falcon.cc @@ -24,7 +24,7 @@ void FALCON::create_falcon_model(FFModel &ff, std::string const &model_config_file_path, std::string const &weight_file_path, InferenceMode mode, - bool use_full_precision) { + DataType data_type) { FalconConfig falcon_config(model_config_file_path); falcon_config.print(); @@ -54,7 +54,7 @@ void FALCON::create_falcon_model(FFModel &ff, falcon_config.vocab_size, falcon_config.hidden_size, AGGR_MODE_NONE, - use_full_precision ? DT_FLOAT : DT_HALF, + data_type, NULL, embed_init, "word_embeddings"); @@ -248,7 +248,7 @@ void FALCON::create_falcon_model(FFModel &ff, falcon_config.hidden_size, falcon_config.hidden_size / falcon_config.n_head, ff.config.tensor_parallelism_degree, - use_full_precision); + true); InferenceManager *im = InferenceManager::get_inference_manager(); im->register_model_weights_loader(&ff, fileloader); @@ -266,7 +266,7 @@ void FALCON::create_falcon_model(FFModel &ff, falcon_config.hidden_size / falcon_config.n_head, ff.config.tensor_parallelism_degree); std::cout << "------load weights ----------" << std::endl; - fileloader.load_weights(&ff, use_full_precision); + fileloader.load_weights(&ff, false); std::cout << "------load weight finished----------" << std::endl; // init operators diff --git a/inference/models/falcon.h b/inference/models/falcon.h index fce2dade3f..7477b3523d 100644 --- a/inference/models/falcon.h +++ b/inference/models/falcon.h @@ -94,7 +94,7 @@ class FALCON { std::string const &model_config_file_path, std::string const &weight_file_path, InferenceMode mode, - bool use_full_precision = false); + DataType data_type); }; }; // namespace FlexFlow diff --git a/inference/models/llama.cc b/inference/models/llama.cc index 14b8c31fa1..fcd5c64bfe 100644 --- a/inference/models/llama.cc +++ b/inference/models/llama.cc @@ -25,7 +25,7 @@ void LLAMA::create_llama_model(FFModel &ff, std::string const &weight_file_path, InferenceMode mode, GenerationConfig generation_config, - bool use_full_precision) { + DataType data_type) { // do not apply cpu offload in beam search model. 
LLAMAConfig llama_config(model_config_file_path); llama_config.print(); @@ -55,7 +55,7 @@ void LLAMA::create_llama_model(FFModel &ff, llama_config.vocab_size, llama_config.hidden_size, AGGR_MODE_NONE, - use_full_precision ? DT_FLOAT : DT_HALF, + data_type, NULL, embed_init, "tok_embeddings"); @@ -273,7 +273,7 @@ void LLAMA::create_llama_model(FFModel &ff, llama_config.hidden_size, llama_config.hidden_size / llama_config.num_attention_heads, ff.config.tensor_parallelism_degree, - use_full_precision); + true); InferenceManager *im = InferenceManager::get_inference_manager(); im->register_model_weights_loader(&ff, fileloader); diff --git a/inference/models/llama.h b/inference/models/llama.h index ba1f0236f9..b0f8c6b207 100644 --- a/inference/models/llama.h +++ b/inference/models/llama.h @@ -83,7 +83,7 @@ class LLAMA { std::string const &weight_file_path, InferenceMode mode, GenerationConfig generation_config, - bool use_full_precision = false); + DataType data_type); }; }; // namespace FlexFlow diff --git a/inference/models/mpt.cc b/inference/models/mpt.cc index 7e8fc8358f..c5572deb47 100644 --- a/inference/models/mpt.cc +++ b/inference/models/mpt.cc @@ -25,7 +25,7 @@ void MPT::create_mpt_model(FFModel &ff, std::string const &weight_file_path, InferenceMode mode, GenerationConfig generationConfig, - bool use_full_precision) { + DataType data_type) { MPTConfig mpt_config(model_config_file_path); mpt_config.print(); @@ -55,7 +55,7 @@ void MPT::create_mpt_model(FFModel &ff, mpt_config.vocab_size, mpt_config.hidden_size, AGGR_MODE_NONE, - use_full_precision ? DT_FLOAT : DT_HALF, + data_type, NULL, embed_init, "transformer_wte"); @@ -255,7 +255,7 @@ void MPT::create_mpt_model(FFModel &ff, mpt_config.hidden_size, mpt_config.hidden_size / mpt_config.n_heads, ff.config.tensor_parallelism_degree, - use_full_precision); + true); InferenceManager *im = InferenceManager::get_inference_manager(); im->register_model_weights_loader(&ff, fileloader); @@ -271,7 +271,7 @@ void MPT::create_mpt_model(FFModel &ff, mpt_config.hidden_size, mpt_config.hidden_size / mpt_config.n_heads, ff.config.tensor_parallelism_degree); - fileloader.load_weights(&ff, use_full_precision); + fileloader.load_weights(&ff, false); im->init_operators_inference(&ff); #endif } diff --git a/inference/models/mpt.h b/inference/models/mpt.h index 08597e1d75..510eaae5b0 100644 --- a/inference/models/mpt.h +++ b/inference/models/mpt.h @@ -70,7 +70,7 @@ class MPT { std::string const &weight_file_path, InferenceMode mode, GenerationConfig generationConfig, - bool use_full_precision = false); + DataType data_type); }; }; // namespace FlexFlow diff --git a/inference/models/opt.cc b/inference/models/opt.cc index 3ff4c96fdf..23f910315e 100644 --- a/inference/models/opt.cc +++ b/inference/models/opt.cc @@ -24,7 +24,7 @@ void OPT::create_opt_model(FFModel &ff, std::string const &model_config_file_path, std::string const &weight_file_path, InferenceMode mode, - bool use_full_precision) { + DataType data_type) { OPTConfig opt_config(model_config_file_path); opt_config.print(); @@ -58,20 +58,19 @@ void OPT::create_opt_model(FFModel &ff, opt_config.vocab_size, opt_config.word_embed_proj_dim, AGGR_MODE_NONE, - use_full_precision ? DT_FLOAT : DT_HALF, + data_type, NULL, embed_init, "embed_tokens"); - Tensor positional_embedding = - ff.embedding(position_input, - opt_config.max_position_embeddings, - opt_config.hidden_size, - AGGR_MODE_NONE, - use_full_precision ? 
DT_FLOAT : DT_HALF, - NULL, - embed_init, - "embed_positions"); + Tensor positional_embedding = ff.embedding(position_input, + opt_config.max_position_embeddings, + opt_config.hidden_size, + AGGR_MODE_NONE, + data_type, + NULL, + embed_init, + "embed_positions"); Tensor fc2 = nullptr, added = nullptr; Tensor res_ln_outputs[2] = {nullptr, nullptr}; @@ -263,7 +262,7 @@ void OPT::create_opt_model(FFModel &ff, opt_config.hidden_size, opt_config.hidden_size / opt_config.num_attention_heads, ff.config.tensor_parallelism_degree, - use_full_precision); + true); InferenceManager *im = InferenceManager::get_inference_manager(); im->register_model_weights_loader(&ff, fileloader); @@ -280,7 +279,7 @@ void OPT::create_opt_model(FFModel &ff, opt_config.hidden_size / opt_config.num_attention_heads, ff.config.tensor_parallelism_degree); - fileloader.load_weights(&ff, use_full_precision); + fileloader.load_weights(&ff, false); std::cout << "------finished loading weights----------" << std::endl; im->init_operators_inference(&ff); #endif diff --git a/inference/models/opt.h b/inference/models/opt.h index 7c736a26d1..cc1cdf576b 100644 --- a/inference/models/opt.h +++ b/inference/models/opt.h @@ -97,7 +97,7 @@ class OPT { std::string const &model_config_file_path, std::string const &weight_file_path, InferenceMode mode, - bool use_full_precision = false); + DataType data_type); }; }; // namespace FlexFlow diff --git a/inference/models/starcoder.cc b/inference/models/starcoder.cc index 2327c86119..71dbfb8fbb 100644 --- a/inference/models/starcoder.cc +++ b/inference/models/starcoder.cc @@ -26,7 +26,7 @@ void STARCODER::create_starcoder_model( std::string const &weight_file_path, InferenceMode mode, GenerationConfig generationConfig, - bool use_full_precision) { + DataType data_type) { // do not apply cpu offload in beam search model. STARCODERConfig startcoder_config(model_config_file_path); startcoder_config.print(); @@ -63,7 +63,7 @@ void STARCODER::create_starcoder_model( startcoder_config.vocab_size, startcoder_config.hidden_size, AGGR_MODE_NONE, - use_full_precision ? DT_FLOAT : DT_HALF, + data_type, NULL, embed_init, "transformer_wte"); @@ -73,7 +73,7 @@ void STARCODER::create_starcoder_model( startcoder_config.max_position_embeddings, startcoder_config.hidden_size, AGGR_MODE_NONE, - use_full_precision ? 
DT_FLOAT : DT_HALF, + data_type, NULL, embed_init, "transformer_wpe"); @@ -230,7 +230,7 @@ void STARCODER::create_starcoder_model( startcoder_config.hidden_size, startcoder_config.hidden_size / startcoder_config.num_attention_heads, ff.config.tensor_parallelism_degree, - use_full_precision); + true); im->register_model_weights_loader(&ff, fileloader); #ifdef DEADCODE // Compile the model diff --git a/inference/models/starcoder.h b/inference/models/starcoder.h index 0e9577d569..fce7a27467 100644 --- a/inference/models/starcoder.h +++ b/inference/models/starcoder.h @@ -71,7 +71,7 @@ class STARCODER { std::string const &weight_file_path, InferenceMode mode, GenerationConfig generationConfig, - bool use_full_precision = false); + DataType data_type); }; }; // namespace FlexFlow diff --git a/inference/spec_infer/spec_infer.cc b/inference/spec_infer/spec_infer.cc index 7578721dd0..9fe686f8e7 100644 --- a/inference/spec_infer/spec_infer.cc +++ b/inference/spec_infer/spec_infer.cc @@ -58,7 +58,9 @@ void parse_input_args(char **argv, int argc, FilePaths &paths, ModelNames &model_names, + DataType &data_type, bool &use_full_precision, + bool &use_bfloat16_precision, bool &verbose, int &max_requests_per_batch, int &max_tokens_per_batch, @@ -98,8 +100,13 @@ void parse_input_args(char **argv, } if (!strcmp(argv[i], "--use-full-precision")) { use_full_precision = true; + data_type = DT_FLOAT; continue; } + if (!strcmp(argv[i], "--use-bfloat16-precision")) { + use_bfloat16_precision = true; + data_type = DT_BF16; + } // verbose logging to stdout if (!strcmp(argv[i], "--verbose")) { verbose = true; @@ -130,7 +137,8 @@ void parse_input_args(char **argv, void get_model_meta(FilePaths &file_paths, ModelMeta &model_metadata, - bool use_full_precision) { + bool use_full_precision, + bool use_bfloat16_precision) { if (model_metadata.model_names.llm_model_name.empty() || model_metadata.model_names.ssm_model_names.size() == 0) { assert(false && "SpecInfer needs at least one LLM and one SSM for " @@ -145,11 +153,13 @@ void get_model_meta(FilePaths &file_paths, join_path({file_paths.cache_folder_path, "tokenizers", model_metadata.model_names.llm_model_name}); - model_metadata.llm_weights_path = - join_path({file_paths.cache_folder_path, - "weights", - model_metadata.model_names.llm_model_name, - use_full_precision ? "full-precision" : "half-precision"}); + model_metadata.llm_weights_path = join_path( + {file_paths.cache_folder_path, + "weights", + model_metadata.model_names.llm_model_name, + use_full_precision + ? "full-precision" + : (use_bfloat16_precision ? "full-precision" : "half-precision")}); std::ifstream llm_config_file_handle(model_metadata.llm_model_config_path); if (!llm_config_file_handle.good()) { @@ -196,11 +206,13 @@ void get_model_meta(FilePaths &file_paths, "config.json"}); std::string ssm_tokenizer_path = join_path({file_paths.cache_folder_path, "tokenizers", ssm_model_name}); - std::string ssm_weights_path = - join_path({file_paths.cache_folder_path, - "weights", - ssm_model_name, - use_full_precision ? "full-precision" : "half-precision"}); + std::string ssm_weights_path = join_path( + {file_paths.cache_folder_path, + "weights", + ssm_model_name, + use_full_precision + ? "full-precision" + : (use_bfloat16_precision ? 
"full-precision" : "half-precision")}); std::ifstream ssm_config_file_handle(ssm_config_path); if (!ssm_config_file_handle.good()) { @@ -265,7 +277,9 @@ void FlexFlow::top_level_task(Task const *task, FFConfig ffconfig; FilePaths file_paths; ModelMeta model_metadata; + DataType data_type = DT_HALF; bool use_full_precision = false; + bool use_bfloat16_precision = false; bool verbose = false; int max_requests_per_batch = 16; int max_tokens_per_batch = 256; @@ -278,13 +292,16 @@ void FlexFlow::top_level_task(Task const *task, argc, file_paths, model_metadata.model_names, + data_type, use_full_precision, + use_bfloat16_precision, verbose, max_requests_per_batch, max_tokens_per_batch, max_sequence_length); - get_model_meta(file_paths, model_metadata, use_full_precision); + get_model_meta( + file_paths, model_metadata, use_full_precision, use_bfloat16_precision); assert(ffconfig.data_parallelism_degree * ffconfig.tensor_parallelism_degree * ffconfig.pipeline_parallelism_degree == @@ -314,26 +331,26 @@ void FlexFlow::top_level_task(Task const *task, model_metadata.llm_weights_path, TREE_VERIFY_MODE, generationConfig, - use_full_precision); + data_type); } else if (model_metadata.llm_model_type == ModelType::OPT) { OPT::create_opt_model(tree_model, model_metadata.llm_model_config_path, model_metadata.llm_weights_path, TREE_VERIFY_MODE, - use_full_precision); + data_type); } else if (model_metadata.llm_model_type == ModelType::FALCON) { FALCON::create_falcon_model(tree_model, model_metadata.llm_model_config_path, model_metadata.llm_weights_path, TREE_VERIFY_MODE, - use_full_precision); + data_type); } else if (model_metadata.llm_model_type == ModelType::MPT) { MPT::create_mpt_model(tree_model, model_metadata.llm_model_config_path, model_metadata.llm_weights_path, TREE_VERIFY_MODE, generationConfig, - use_full_precision); + data_type); } else { assert(false && "Invalid LLM model type passed (or no type was passed)."); } @@ -358,27 +375,27 @@ void FlexFlow::top_level_task(Task const *task, model_metadata.ssm_model_weights_paths[ssm_id], BEAM_SEARCH_MODE, generationConfig, - use_full_precision); + data_type); } else if (model_metadata.ssm_model_types[ssm_id] == ModelType::OPT) { OPT::create_opt_model(beam_model, model_metadata.ssm_model_config_paths[ssm_id], model_metadata.ssm_model_weights_paths[ssm_id], BEAM_SEARCH_MODE, - use_full_precision); + data_type); } else if (model_metadata.ssm_model_types[ssm_id] == ModelType::FALCON) { FALCON::create_falcon_model( beam_model, model_metadata.ssm_model_config_paths[ssm_id], model_metadata.ssm_model_weights_paths[ssm_id], BEAM_SEARCH_MODE, - use_full_precision); + data_type); } else if (model_metadata.ssm_model_types[ssm_id] == ModelType::MPT) { MPT::create_mpt_model(beam_model, model_metadata.ssm_model_config_paths[ssm_id], model_metadata.ssm_model_weights_paths[ssm_id], BEAM_SEARCH_MODE, generationConfig, - use_full_precision); + data_type); } else { assert(false && "Invalid SSM model type passed."); } diff --git a/inference/utils/download_hf_model.py b/inference/utils/download_hf_model.py index 03fc8e1633..2a23527fb7 100644 --- a/inference/utils/download_hf_model.py +++ b/inference/utils/download_hf_model.py @@ -30,6 +30,11 @@ def parse_args(): action="store_true", help="Only download the half precision version of the weights", ) + group.add_argument( + "--bfloat16-precision-only", + action="store_true", + help="Only download the bfloat16 precision version of the weights", + ) args = parser.parse_args() return args @@ -39,8 +44,10 @@ def main(args): 
data_types = ff.DataType.DT_FLOAT elif args.half_precision_only: data_types = ff.DataType.DT_HALF + elif args.bfloat16_precision_only: + data_types = ff.DataType.DT_BF16 else: - data_types = (ff.DataType.DT_FLOAT, ff.DataType.DT_HALF) + data_types = (ff.DataType.DT_FLOAT, ff.DataType.DT_HALF, ff.DataType.DT_BF16) for model_name in args.model_names: for data_type in data_types: diff --git a/python/flexflow/serve/models/llama.py b/python/flexflow/serve/models/llama.py index 900ab48bcd..337f872dd8 100644 --- a/python/flexflow/serve/models/llama.py +++ b/python/flexflow/serve/models/llama.py @@ -15,7 +15,7 @@ from flexflow.core import * from .base import FlexFlowModel import random - +import torch class LLAMAConfig: def __init__(self, hf_config): diff --git a/python/flexflow/serve/serve.py b/python/flexflow/serve/serve.py index d1a935e5fc..3231b222b4 100644 --- a/python/flexflow/serve/serve.py +++ b/python/flexflow/serve/serve.py @@ -113,7 +113,7 @@ def __init__( self.config_class, ) = self.__get_ff_model_type() self.data_type = data_type - assert self.data_type == DataType.DT_HALF or self.data_type == DataType.DT_FLOAT + assert self.data_type == DataType.DT_HALF or self.data_type == DataType.DT_FLOAT or self.data_type == DataType.DT_BF16 self.cache_path = cache_path if len(cache_path) > 0 else "~/.cache/flexflow" self.refresh_cache = refresh_cache self.output_file = output_file @@ -177,6 +177,8 @@ def download_hf_weights_if_needed(self): torch.set_default_tensor_type(torch.HalfTensor) elif self.data_type == DataType.DT_FLOAT: torch.set_default_tensor_type(torch.FloatTensor) + elif self.data_type == DataType.DT_BF16: + torch.set_default_tensor_type(torch.BFloat16Tensor) else: assert False, "Data type not yet supported -- cannot download weights!" @@ -187,6 +189,7 @@ def download_hf_weights_if_needed(self): self.model_name.lower(), "full-precision" if self.data_type == DataType.DT_FLOAT + else "bfloat16-precision" if self.data_type == DataType.DT_BF16 else "half-precision", ) if self.refresh_cache: diff --git a/python/flexflow/type.py b/python/flexflow/type.py index 994a85f57e..04de9fa7d6 100644 --- a/python/flexflow/type.py +++ b/python/flexflow/type.py @@ -35,6 +35,7 @@ class DataType(Enum): DT_HALF = 43 DT_FLOAT = 44 DT_DOUBLE = 45 + DT_BF16 = 48 DT_NONE = 49 diff --git a/src/ops/add_bias_residual_layer_norm.cu b/src/ops/add_bias_residual_layer_norm.cu index ceb1a6514e..e503c2603d 100644 --- a/src/ops/add_bias_residual_layer_norm.cu +++ b/src/ops/add_bias_residual_layer_norm.cu @@ -107,13 +107,13 @@ __global__ void LayerNormFusedForwardKernel(int64_t N, T *Y) { __shared__ float m_shared[C10_WARP_SIZE]; __shared__ float v_shared[C10_WARP_SIZE]; - const int64_t i = blockIdx.x; + int64_t const i = blockIdx.x; float sum1 = 0.0f; float sum2 = 0.0f; for (int64_t j = threadIdx.x; j < N; j += min(blockDim.x, kCUDABlockReduceNumThreads)) { - const int64_t index = i * N + j; - const int64_t bias_idx = index % attn_bias_dim; + int64_t const index = i * N + j; + int64_t const bias_idx = index % attn_bias_dim; X[index] = input_ptr[index] + attn_bias_ptr[bias_idx] + residual_ptr[index]; sum1 += static_cast(X[index]); sum2 += static_cast(X[index]) * static_cast(X[index]); @@ -136,11 +136,11 @@ __global__ void LayerNormFusedForwardKernel(int64_t N, using T_ACC = T; for (int64_t j = threadIdx.x; j < N; j += min(blockDim.x, kCUDANumThreads)) { - const int64_t index = i * N + j; + int64_t const index = i * N + j; const T_ACC gamma_v = - gamma == nullptr ? T_ACC(1) : static_cast(gamma[j]); + gamma == nullptr ? 
T_ACC(1.0f) : static_cast(gamma[j]); const T_ACC beta_v = - beta == nullptr ? T_ACC(0) : static_cast(beta[j]); + beta == nullptr ? T_ACC(0.0f) : static_cast(beta[j]); Y[index] = (static_cast(X[index]) - static_cast(mean[i])) * static_cast(rstd[i]) * gamma_v + beta_v; @@ -234,6 +234,20 @@ void AddBiasResidualLayerNorm::inference_kernel_wrapper( m->elementwise_affine ? gamma.get_half_ptr() : nullptr, (m->elementwise_affine && m->use_bias) ? beta.get_half_ptr() : nullptr, stream); + } else if (m->input_type[0] == DT_BF16) { + AddBiasResidualLayerNorm::inference_kernel<__nv_bfloat16>( + m, + attn_bias_dim, + residual_volume, + input.get_bfloat16_ptr(), + attn_bias.get_bfloat16_ptr(), + residual.get_bfloat16_ptr(), + added_output.get_bfloat16_ptr(), + output.get_bfloat16_ptr(), + m->elementwise_affine ? gamma.get_bfloat16_ptr() : nullptr, + (m->elementwise_affine && m->use_bias) ? beta.get_bfloat16_ptr() + : nullptr, + stream); } else { assert(false && "unsupport datatype in layernorm"); } diff --git a/src/ops/arg_topk.cpp b/src/ops/arg_topk.cpp index f431d3d4bf..54166ab8d2 100644 --- a/src/ops/arg_topk.cpp +++ b/src/ops/arg_topk.cpp @@ -515,6 +515,18 @@ void ArgTopK::forward_kernel_wrapper(ArgTopKMeta const *m, m->sorted, m->speculative_decoding ? bc : nullptr, stream); + } else if (input.data_type == DT_BF16) { + ArgTopK::forward_kernel(m, + input.get_bfloat16_ptr(), + m->speculative_decoding ? probs.get_float_ptr() + : nullptr, + indices.get_int32_ptr(), + batch_size, + length, + k, + m->sorted, + m->speculative_decoding ? bc : nullptr, + stream); } else { assert(false && "Unsupported data type"); } diff --git a/src/ops/arg_topk.cu b/src/ops/arg_topk.cu index 5b7978812c..671e9aeb25 100644 --- a/src/ops/arg_topk.cu +++ b/src/ops/arg_topk.cu @@ -520,7 +520,19 @@ void ArgTopK::forward_kernel_wrapper(ArgTopKMeta const *m, m->sorted, m->speculative_decoding ? bc : nullptr, stream); - } else { + } else if (input.data_type == DT_BF16) { + ArgTopK::forward_kernel(m, + input.get_bfloat16_ptr(), + m->speculative_decoding ? probs.get_float_ptr() + : nullptr, + indices.get_int32_ptr(), + batch_size, + length, + k, + m->sorted, + m->speculative_decoding ? bc : nullptr, + stream); + }else { assert(false && "Unsupported data type"); } diff --git a/src/ops/argmax.cpp b/src/ops/argmax.cpp index 8a1cf0b3b0..1c94585ba1 100644 --- a/src/ops/argmax.cpp +++ b/src/ops/argmax.cpp @@ -466,6 +466,16 @@ void ArgMax::forward_kernel_wrapper(ArgMaxMeta const *m, length, batch_size, stream); + } else if (input.data_type == DT_BF16) { + ArgMax::forward_kernel(m, + input.get_bfloat16_ptr(), + indices.get_int32_ptr(), + m->probs, + m->beam_search ? 
parent.get_int32_ptr() + : nullptr, + length, + batch_size, + stream); } else { assert(false && "Unsupported data type"); } @@ -491,7 +501,7 @@ ArgMaxMeta::ArgMaxMeta(FFHandler handler, : OpMeta(handler, op) { DataType data_type = op->data_type; size_t prob_size = batch_size; - assert(data_type == DT_FLOAT || data_type == DT_HALF); + assert(data_type == DT_FLOAT || data_type == DT_HALF || data_type == DT_BF16); size_t total_size = prob_size * sizeof(float); gpu_mem_allocator.create_legion_instance(reserveInst, total_size); probs = gpu_mem_allocator.allocate_instance(prob_size); diff --git a/src/ops/argmax.cu b/src/ops/argmax.cu index 05c84719c1..aebafc2fff 100644 --- a/src/ops/argmax.cu +++ b/src/ops/argmax.cu @@ -123,6 +123,16 @@ void ArgMax::forward_kernel_wrapper(ArgMaxMeta const *m, length, batch_size, stream); + } else if (input.data_type == DT_BF16) { + ArgMax::forward_kernel<__nv_bfloat16>( + m, + input.get_bfloat16_ptr(), + indices.get_int32_ptr(), + m->probs, + m->beam_search ? parent.get_int32_ptr() : nullptr, + length, + batch_size, + stream); } else { assert(false && "Unsupported data type"); } @@ -153,20 +163,27 @@ ArgMaxMeta::ArgMaxMeta(FFHandler handler, size_t d_offsets_size = batch_size; size_t prob_size = batch_size; - assert(data_type == DT_FLOAT || data_type == DT_HALF); + assert(data_type == DT_FLOAT || data_type == DT_HALF || data_type == DT_BF16); size_t total_size = d_offsets_size * sizeof(int) + (data_type == DT_FLOAT ? sizeof(cub::KeyValuePair) * batch_size - : sizeof(cub::KeyValuePair) * batch_size) + + : (data_type == DT_HALF + ? sizeof(cub::KeyValuePair) * batch_size + : sizeof(cub::KeyValuePair) * + batch_size)) + prob_size * sizeof(float); gpu_mem_allocator.create_legion_instance(reserveInst, total_size); d_offsets = gpu_mem_allocator.allocate_instance(d_offsets_size); d_out = data_type == DT_FLOAT ? gpu_mem_allocator.allocate_instance_untyped( batch_size * sizeof(cub::KeyValuePair)) - : gpu_mem_allocator.allocate_instance_untyped( - batch_size * sizeof(cub::KeyValuePair)); + : (data_type == DT_HALF + ? 
gpu_mem_allocator.allocate_instance_untyped( + batch_size * sizeof(cub::KeyValuePair)) + : gpu_mem_allocator.allocate_instance_untyped( + batch_size * + sizeof(cub::KeyValuePair))); probs = gpu_mem_allocator.allocate_instance(prob_size); // init offset int parallelism = total_ele; @@ -197,6 +214,16 @@ ArgMaxMeta::ArgMaxMeta(FFHandler handler, d_offsets, d_offsets + 1, stream)); + } else if (data_type == DT_BF16) { + checkCUDA(cub::DeviceSegmentedReduce::ArgMax( + d_temp_storage, + temp_storage_bytes, + input.get_bfloat16_ptr(), + static_cast *>(d_out), + batch_size, + d_offsets, + d_offsets + 1, + stream)); } gpu_mem_allocator.create_legion_instance(reserveInst, temp_storage_bytes); diff --git a/src/ops/element_unary.cpp b/src/ops/element_unary.cpp index e20200420f..06a4d3c36a 100644 --- a/src/ops/element_unary.cpp +++ b/src/ops/element_unary.cpp @@ -314,6 +314,12 @@ template void int64_t *output_ptr, size_t num_elements); +template void ElementUnary::forward_kernel_wrapper( + ElementUnaryMeta const *m, + hip_bfloat16 const *input_ptr, + hip_bfloat16 *output_ptr, + size_t num_elements); + template void ElementUnary::backward_kernel_wrapper(ElementUnaryMeta const *m, float const *input_ptr, diff --git a/src/ops/element_unary.cu b/src/ops/element_unary.cu index c7f5e90f4c..ed4fb15e18 100644 --- a/src/ops/element_unary.cu +++ b/src/ops/element_unary.cu @@ -307,6 +307,12 @@ template void float const *input_ptr, float *output_ptr, size_t num_elements); +template void + ElementUnary::forward_kernel_wrapper<__nv_bfloat16>(ElementUnaryMeta const *m, + __nv_bfloat16 const *input_ptr, + __nv_bfloat16 *output_ptr, + size_t num_elements); + template void ElementUnary::forward_kernel_wrapper(ElementUnaryMeta const *m, double const *input_ptr, @@ -329,6 +335,7 @@ template void float const *output_ptr, float const *output_grad_ptr, size_t num_elements); + template void ElementUnary::backward_kernel_wrapper(ElementUnaryMeta const *m, double const *input_ptr, diff --git a/src/ops/fused.cpp b/src/ops/fused.cpp index 3282bc57d9..2daa48c3d0 100644 --- a/src/ops/fused.cpp +++ b/src/ops/fused.cpp @@ -428,6 +428,11 @@ __host__ void FusedOp::forward_task(Task const *task, m, my_input_accessor[0].get_float_ptr(), my_output_accessor[0].get_float_ptr()); + } else if (m->input_type == DT_BF16) { + Kernels::Softmax::forward_kernel_wrapper( + m, + my_input_accessor[0].get_bfloat16_ptr(), + my_output_accessor[0].get_bfloat16_ptr()); } break; } @@ -815,6 +820,12 @@ __host__ void my_input_accessor[0].get_float_ptr(), my_output_accessor[0].get_float_ptr(), my_input_accessor[0].domain.get_volume()); + } else if (m->data_type == DT_BF16) { + ElementUnary::forward_kernel_wrapper( + m, + my_input_accessor[0].get_bfloat16_ptr(), + my_output_accessor[0].get_bfloat16_ptr(), + my_input_accessor[0].domain.get_volume()); } else { assert(false && "Unsupported data type in ElementUnary forward"); } @@ -1039,6 +1050,11 @@ __host__ void m, my_input_accessor[0].get_float_ptr(), my_output_accessor[0].get_float_ptr()); + } else if (m->input_type == DT_BF16) { + Kernels::Softmax::forward_kernel_wrapper( + m, + my_input_accessor[0].get_bfloat16_ptr(), + my_output_accessor[0].get_bfloat16_ptr()); } break; } diff --git a/src/ops/fused.cu b/src/ops/fused.cu index c6ba0b04c5..16b9a2b782 100644 --- a/src/ops/fused.cu +++ b/src/ops/fused.cu @@ -442,6 +442,11 @@ __host__ void FusedOp::forward_task(Task const *task, m, my_input_accessor[0].get_float_ptr(), my_output_accessor[0].get_float_ptr()); + } else if (m->input_type == DT_BF16) { + 
Kernels::Softmax::forward_kernel_wrapper( + m, + my_input_accessor[0].get_bfloat16_ptr(), + my_output_accessor[0].get_bfloat16_ptr()); } break; } @@ -850,6 +855,12 @@ __host__ void my_input_accessor[0].get_float_ptr(), my_output_accessor[0].get_float_ptr(), my_input_accessor[0].domain.get_volume()); + } else if (m->data_type == DT_BF16) { + ElementUnary::forward_kernel_wrapper( + m, + my_input_accessor[0].get_bfloat16_ptr(), + my_output_accessor[0].get_bfloat16_ptr(), + my_input_accessor[0].domain.get_volume()); } else { assert(false && "Unsupported data type in ElementUnary forward"); } @@ -1076,6 +1087,11 @@ __host__ void m, my_input_accessor[0].get_float_ptr(), my_output_accessor[0].get_float_ptr()); + } else if (m->input_type == DT_BF16) { + Kernels::Softmax::forward_kernel_wrapper( + m, + my_input_accessor[0].get_bfloat16_ptr(), + my_output_accessor[0].get_bfloat16_ptr()); } break; } diff --git a/src/ops/inc_multihead_self_attention.cpp b/src/ops/inc_multihead_self_attention.cpp index d60386f927..ca83e7196a 100644 --- a/src/ops/inc_multihead_self_attention.cpp +++ b/src/ops/inc_multihead_self_attention.cpp @@ -802,6 +802,23 @@ void IncMultiHeadSelfAttention::inference_kernel_wrapper( output.get_float_ptr(), bias_ptr, stream); + } else if (input.data_type == DT_BF16) { + if (m->offload) { + pre_build_weight_kernel(m, weight, input.data_type, stream); + } + hip_bfloat16 const *bias_ptr = + use_bias ? bias.get_bfloat16_ptr() + : static_cast(nullptr); + Kernels::IncMultiHeadAttention::inference_kernel( + m, + bc, + shard_id, + input.get_bfloat16_ptr(), + m->offload ? static_cast(m->weight_ptr) + : weight.get_bfloat16_ptr(), + output.get_bfloat16_ptr(), + bias_ptr, + stream); } else { assert(false && "Unspported data type"); } @@ -1098,4 +1115,11 @@ template void Kernels::IncMultiHeadAttention::pre_build_weight_kernel( DataType data_type, hipStream_t stream); +template void + Kernels::IncMultiHeadAttention::pre_build_weight_kernel( + IncMultiHeadSelfAttentionMeta const *m, + GenericTensorAccessorR const weight, + DataType data_type, + cudaStream_t stream); + }; // namespace FlexFlow diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index 42933cee27..9b4cffad4c 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -506,12 +506,15 @@ void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m, cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; if (m->output_type[0] == DT_FLOAT) { compute_type = CUBLAS_COMPUTE_32F_FAST_16F; + } else if (m->output_type[0] == DT_BF16) { + compute_type = CUBLAS_COMPUTE_32F; } #endif // Step 1: Compute QKV projections { - DT alpha = 1.0f, beta = 0.0f; + typename cublasAlphaBetaType
<DT>::type alpha = 1.0; + typename cublasAlphaBetaType<DT>
::type beta = 0.0; // after transpositions int m_q = m->qProjSize * m->num_q_heads; int m_k = m->kProjSize * m->num_q_heads; @@ -637,10 +640,14 @@ void compute_o_prod_bias(IncMultiHeadSelfAttentionMeta const *m, cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; #else cudaDataType_t compute_type = cublas_data_type; + if (m->output_type[0] == DT_BF16) { + compute_type = CUDA_R_32F; + } #endif // Project to output, save result directly on output tensor { - DT alpha = 1.0f, beta = 0.0f; + typename cublasAlphaBetaType
<DT>::type alpha = 1.0; + typename cublasAlphaBetaType<DT>
::type beta = 0.0; // after transpositions int m_ = m->oProjSize; int k = m->vProjSize * m->num_q_heads; @@ -792,6 +799,12 @@ void pre_build_weight_kernel(IncMultiHeadSelfAttentionMeta const *m, m->weightSize, cudaMemcpyHostToDevice, stream); + } else if (data_type == DT_BF16) { + cudaMemcpyAsync(m->weight_ptr, + weight.get_bfloat16_ptr(), + m->weightSize, + cudaMemcpyHostToDevice, + stream); } else { assert(false); } @@ -914,6 +927,8 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta const *m, cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; if (m->output_type[0] == DT_FLOAT) { compute_type = CUBLAS_COMPUTE_32F_FAST_16F; + } else if (m->output_type[0] == DT_BF16) { + compute_type = CUBLAS_COMPUTE_32F; } #endif // int num_requests = bc->num_active_requests(); @@ -938,7 +953,8 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta const *m, // Step 1: compute query-key product QK.T/sqrt(d_k) { // Scale by sqrt(d_k) as per the original attention paper - DT alpha = 1.0f, beta = 0.0f; + typename cublasAlphaBetaType
<DT>::type alpha = 1.0; + typename cublasAlphaBetaType<DT>
<DT>::type beta = 0.0; if (*m->qk_prod_scaling) { alpha = static_cast<DT>
(1.0f / sqrt(m->kProjSize)); } @@ -1069,7 +1085,8 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta const *m, // Step 5: Matmul softmax(QK.T/sqrt(d_k)) by V. Implemented as V @ // softmax(QK.T/sqrt(d_k)).T { - DT alpha = 1.0f, beta = 0.0f; + typename cublasAlphaBetaType
<DT>::type alpha = 1.0; + typename cublasAlphaBetaType<DT>
::type beta = 0.0; // after transpositions int m_ = m->vProjSize; int n = num_new_tokens; @@ -1185,6 +1202,24 @@ void IncMultiHeadSelfAttention::inference_kernel_wrapper( output.get_float_ptr(), bias_ptr, stream); + } else if (input.data_type == DT_BF16) { + if (m->offload) { + pre_build_weight_kernel<__nv_bfloat16>( + m, weight, input.data_type, stream); + } + __nv_bfloat16 const *bias_ptr = + use_bias ? bias.get_bfloat16_ptr() + : static_cast<__nv_bfloat16 const *>(nullptr); + Kernels::IncMultiHeadAttention::inference_kernel( + m, + bc, + shard_id, + input.get_bfloat16_ptr(), + m->offload ? static_cast<__nv_bfloat16 *>(m->weight_ptr) + : weight.get_bfloat16_ptr(), + output.get_bfloat16_ptr(), + bias_ptr, + stream); } else { assert(false && "Unspported data type"); } @@ -1493,6 +1528,13 @@ template void Kernels::IncMultiHeadAttention::pre_build_weight_kernel( DataType data_type, cudaStream_t stream); +template void + Kernels::IncMultiHeadAttention::pre_build_weight_kernel<__nv_bfloat16>( + IncMultiHeadSelfAttentionMeta const *m, + GenericTensorAccessorR const weight, + DataType data_type, + cudaStream_t stream); + template void Kernels::IncMultiHeadAttention::compute_o_prod_bias( IncMultiHeadSelfAttentionMeta const *m, BatchConfig const *bc, @@ -1513,6 +1555,17 @@ template void Kernels::IncMultiHeadAttention::compute_o_prod_bias( int num_tokens, cudaStream_t stream); +template void + Kernels::IncMultiHeadAttention::compute_o_prod_bias<__nv_bfloat16>( + IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + int shard_id, + __nv_bfloat16 *output_ptr, + __nv_bfloat16 const *weight_ptr, + __nv_bfloat16 const *bias_ptr, + int num_tokens, + cudaStream_t stream); + template void Kernels::IncMultiHeadAttention::compute_attention_kernel_generation( IncMultiHeadSelfAttentionMeta const *m, @@ -1526,4 +1579,11 @@ template void BatchConfig const *bc, half *output_ptr, cudaStream_t stream); + +template void + Kernels::IncMultiHeadAttention::compute_attention_kernel_generation< + __nv_bfloat16>(IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + __nv_bfloat16 *output_ptr, + cudaStream_t stream); }; // namespace FlexFlow diff --git a/src/ops/kernels/decompress_kernels.cpp b/src/ops/kernels/decompress_kernels.cpp index 22bf93d449..5127e0842e 100644 --- a/src/ops/kernels/decompress_kernels.cpp +++ b/src/ops/kernels/decompress_kernels.cpp @@ -54,10 +54,20 @@ template __global__ void decompress_int4_general_weights( char const *input_weight_ptr, float *weight_ptr, int in_dim, int valueSize); template __global__ void decompress_int4_general_weights( char const *input_weight_ptr, half *weight_ptr, int in_dim, int valueSize); +template __global__ void + decompress_int4_general_weights(char const *input_weight_ptr, + hip_bfloat16 *weight_ptr, + int in_dim, + int valueSize); template __global__ void decompress_int8_general_weights( char const *input_weight_ptr, float *weight_ptr, int in_dim, int valueSize); template __global__ void decompress_int8_general_weights( char const *input_weight_ptr, half *weight_ptr, int in_dim, int valueSize); +template __global__ void + decompress_int8_general_weights(char const *input_weight_ptr, + hip_bfloat16 *weight_ptr, + int in_dim, + int valueSize); template __global__ void decompress_int4_attention_weights(char *input_weight_ptr, float *weight_ptr, @@ -71,7 +81,12 @@ template __global__ void int qProjSize, int qSize, int num_heads); - +template __global__ void + decompress_int4_attention_weights(char *input_weight_ptr, + hip_bfloat16 *weight_ptr, + 
int qProjSize, + int qSize, + int num_heads); template __global__ void decompress_int8_attention_weights(char *input_weight_ptr, float *weight_ptr, @@ -86,5 +101,11 @@ template __global__ void int qSize, int num_heads); +template __global__ void + decompress_int8_attention_weights(char *input_weight_ptr, + hip_bfloat16 *weight_ptr, + int qProjSize, + int qSize, + int num_heads); } // namespace Kernels }; // namespace FlexFlow \ No newline at end of file diff --git a/src/ops/kernels/decompress_kernels.cu b/src/ops/kernels/decompress_kernels.cu index 2e02ce1eec..2165967477 100644 --- a/src/ops/kernels/decompress_kernels.cu +++ b/src/ops/kernels/decompress_kernels.cu @@ -209,10 +209,14 @@ template __global__ void decompress_int4_general_weights( char const *input_weight_ptr, float *weight_ptr, int in_dim, int valueSize); template __global__ void decompress_int4_general_weights( char const *input_weight_ptr, half *weight_ptr, int in_dim, int valueSize); +template __global__ void decompress_int4_general_weights<__nv_bfloat16>( + char const *input_weight_ptr, __nv_bfloat16 *weight_ptr, int in_dim, int valueSize); template __global__ void decompress_int8_general_weights( char const *input_weight_ptr, float *weight_ptr, int in_dim, int valueSize); template __global__ void decompress_int8_general_weights( char const *input_weight_ptr, half *weight_ptr, int in_dim, int valueSize); +template __global__ void decompress_int8_general_weights<__nv_bfloat16>( + char const *input_weight_ptr, __nv_bfloat16 *weight_ptr, int in_dim, int valueSize); template __global__ void decompress_int4_attention_weights(char *input_weight_ptr, float *weight_ptr, @@ -226,6 +230,12 @@ template __global__ void int qProjSize, int qSize, int num_heads); +template __global__ void + decompress_int4_attention_weights<__nv_bfloat16>(char *input_weight_ptr, + __nv_bfloat16 *weight_ptr, + int qProjSize, + int qSize, + int num_heads); template __global__ void decompress_int8_attention_weights(char *input_weight_ptr, @@ -240,6 +250,13 @@ template __global__ void int qProjSize, int qSize, int num_heads); + +template __global__ void + decompress_int8_attention_weights<__nv_bfloat16>(char *input_weight_ptr, + __nv_bfloat16 *weight_ptr, + int qProjSize, + int qSize, + int num_heads); // template // void decompress_weight_bias(T1 *input_weight_ptr, // T2 *weight_ptr, diff --git a/src/ops/kernels/element_binary_kernels.cpp b/src/ops/kernels/element_binary_kernels.cpp index a65372de85..a353ab76ea 100644 --- a/src/ops/kernels/element_binary_kernels.cpp +++ b/src/ops/kernels/element_binary_kernels.cpp @@ -82,8 +82,24 @@ void forward_kernel_wrapper(ElementBinaryMeta const *m, } // print_tensor(in1_ptr, in1_domain.get_volume(), "input1:"); // print_tensor(in2_ptr, in2_domain.get_volume(), "input2:"); - Internal::forward_kernel( - m, in1.get_float_ptr(), in2.get_float_ptr(), out.get_float_ptr(), stream); + if (out.data_type == DT_HALF) { + Internal::forward_kernel( + m, in1.get_half_ptr(), in2.get_half_ptr(), out.get_half_ptr(), stream); + } else if (out.data_type == DT_FLOAT) { + Internal::forward_kernel(m, + in1.get_float_ptr(), + in2.get_float_ptr(), + out.get_float_ptr(), + stream); + } else if (out.data_type == DT_BF16) { + Internal::forward_kernel(m, + in1.get_bfloat16_ptr(), + in2.get_bfloat16_ptr(), + out.get_bfloat16_ptr(), + stream); + } else { + assert(false && "Unsupported data type"); + } // print_tensor(out_ptr, in1_domain.get_volume(), "output:"); if (m->profiling) { checkCUDA(hipEventRecord(t_end, stream)); diff --git 
a/src/ops/kernels/element_binary_kernels.cu b/src/ops/kernels/element_binary_kernels.cu index 42b31a664a..a762f17774 100644 --- a/src/ops/kernels/element_binary_kernels.cu +++ b/src/ops/kernels/element_binary_kernels.cu @@ -105,6 +105,9 @@ void forward_kernel_wrapper(ElementBinaryMeta const *m, in2.get_float_ptr(), out.get_float_ptr(), stream); + } else if (out.data_type == DT_BF16) { + Internal::forward_kernel( + m, in1.get_bfloat16_ptr(), in2.get_bfloat16_ptr(), out.get_bfloat16_ptr(), stream); } else { assert(false && "Unsupported data type"); } diff --git a/src/ops/kernels/embedding_kernels.cpp b/src/ops/kernels/embedding_kernels.cpp index ee4a6fcea1..7612abd0f0 100644 --- a/src/ops/kernels/embedding_kernels.cpp +++ b/src/ops/kernels/embedding_kernels.cpp @@ -60,7 +60,7 @@ void forward_kernel_wrapper(EmbeddingMeta const *m, m->aggr, output.domain.get_volume(), stream); - } else if (weight.data_type == DT_HALF) { + } else if (weight.data_type == DT_DOUBLE) { Internal::forward_kernel(input.get_int32_ptr(), output.get_double_ptr(), weight.get_double_ptr(), @@ -70,6 +70,16 @@ void forward_kernel_wrapper(EmbeddingMeta const *m, m->aggr, output.domain.get_volume(), stream); + } else if (weight.data_type == DT_BF16) { + Internal::forward_kernel(input.get_int32_ptr(), + output.get_bfloat16_ptr(), + weight.get_bfloat16_ptr(), + in_dim, + out_dim, + batch_size, + m->aggr, + output.domain.get_volume(), + stream); } else { assert(false && "Unsupported DataType in Embedding"); } @@ -104,6 +114,16 @@ void forward_kernel_wrapper(EmbeddingMeta const *m, m->aggr, output.domain.get_volume(), stream); + } else if (weight.data_type == DT_BF16) { + Internal::forward_kernel(input.get_int64_ptr(), + output.get_bfloat16_ptr(), + weight.get_bfloat16_ptr(), + in_dim, + out_dim, + batch_size, + m->aggr, + output.domain.get_volume(), + stream); } else { assert(false && "Unsupported DataType in Embedding"); } @@ -162,6 +182,16 @@ void backward_kernel_wrapper(EmbeddingMeta const *m, m->aggr, output.domain.get_volume(), stream); + } else if (m->output_type[0] == DT_BF16) { + Internal::backward_kernel(input.get_int32_ptr(), + output.get_bfloat16_ptr(), + weight_grad.get_bfloat16_ptr(), + in_dim, + out_dim, + batch_size, + m->aggr, + output.domain.get_volume(), + stream); } else { assert(false && "Unsupported DataType in Embedding"); } @@ -196,6 +226,16 @@ void backward_kernel_wrapper(EmbeddingMeta const *m, m->aggr, output.domain.get_volume(), stream); + } else if (m->output_type[0] == DT_BF16) { + Internal::backward_kernel(input.get_int64_ptr(), + output.get_bfloat16_ptr(), + weight_grad.get_bfloat16_ptr(), + in_dim, + out_dim, + batch_size, + m->aggr, + output.domain.get_volume(), + stream); } else { assert(false && "Unsupported DataType in Embedding"); } @@ -332,6 +372,50 @@ __global__ void embed_backward_no_aggr(int64_t const *input, } } +template <> +__global__ void + embed_backward_no_aggr(int const *input, + hip_bfloat16 const *output, + hip_bfloat16 *embed, + int out_dim, + int batch_size) { + CUDA_KERNEL_LOOP(i, batch_size * out_dim) { + int idx = i / out_dim; + int off = i % out_dim; + int wordIdx = input[idx]; +#if __CUDA_ARCH__ >= 700 + atomicAdd(embed + wordIdx * out_dim + off, output[i]); +#else + assert(false); + // TODO: this implementation may result in race condition + // so we use an assertion failure to warn users + embed[wordIdx * out_dim + off] += output[i]; +#endif + } +} + +template <> +__global__ void + embed_backward_no_aggr(int64_t const *input, + hip_bfloat16 const *output, + hip_bfloat16 
*embed, + int out_dim, + int batch_size) { + CUDA_KERNEL_LOOP(i, batch_size * out_dim) { + int idx = i / out_dim; + int off = i % out_dim; + int64_t wordIdx = input[idx]; +#if __CUDA_ARCH__ >= 700 + atomicAdd(embed + wordIdx * out_dim + off, output[i]); +#else + assert(false); + // TODO: this implementation may result in race condition + // so we use an assertion failure to warn users + embed[wordIdx * out_dim + off] += output[i]; +#endif + } +} + template __global__ void embed_backward_with_aggr(TI const *input, TD const *output, @@ -426,6 +510,74 @@ __global__ void embed_backward_with_aggr(int64_t const *input, } } +template <> +__global__ void + embed_backward_with_aggr(int const *input, + hip_bfloat16 const *output, + hip_bfloat16 *embed, + int out_dim, + int in_dim, + int batch_size, + AggrMode aggr) { + hip_bfloat16 scale = 1.0f / in_dim; + CUDA_KERNEL_LOOP(i, batch_size * out_dim) { + int idx = i / out_dim; + int off = i % out_dim; + hip_bfloat16 gradient; + if (aggr == AGGR_MODE_SUM) { + gradient = output[i]; + } else { + assert(aggr == AGGR_MODE_AVG); + gradient = output[i] * scale; + } + for (int j = 0; j < in_dim; j++) { + int wordIdx = input[idx * in_dim + j]; +#if __CUDA_ARCH__ >= 700 + atomicAdd(embed + wordIdx * out_dim + off, gradient); +#else + assert(false); + // TODO: this implementation may result in race condition + // so we use an assertion failure to warn users + embed[wordIdx * out_dim + off] += gradient; +#endif + } + } +} + +template <> +__global__ void + embed_backward_with_aggr(int64_t const *input, + hip_bfloat16 const *output, + hip_bfloat16 *embed, + int out_dim, + int in_dim, + int batch_size, + AggrMode aggr) { + hip_bfloat16 scale = 1.0f / in_dim; + CUDA_KERNEL_LOOP(i, batch_size * out_dim) { + int idx = i / out_dim; + int off = i % out_dim; + hip_bfloat16 gradient; + if (aggr == AGGR_MODE_SUM) { + gradient = output[i]; + } else { + assert(aggr == AGGR_MODE_AVG); + gradient = output[i] * scale; + } + for (int j = 0; j < in_dim; j++) { + int64_t wordIdx = input[idx * in_dim + j]; +#if __CUDA_ARCH__ >= 700 + atomicAdd(embed + wordIdx * out_dim + off, gradient); +#else + assert(false); + // TODO: this implementation may result in race condition + // so we use an assertion failure to warn users + embed[wordIdx * out_dim + off] += gradient; +#endif + } + } +} + /*static*/ template void forward_kernel(TI const *input_ptr, diff --git a/src/ops/kernels/embedding_kernels.cu b/src/ops/kernels/embedding_kernels.cu index 22d8161ff1..911aa5c45d 100644 --- a/src/ops/kernels/embedding_kernels.cu +++ b/src/ops/kernels/embedding_kernels.cu @@ -70,6 +70,16 @@ void forward_kernel_wrapper(EmbeddingMeta const *m, m->aggr, output.domain.get_volume(), stream); + } else if (weight.data_type == DT_BF16) { + Internal::forward_kernel(input.get_int32_ptr(), + output.get_bfloat16_ptr(), + weight.get_bfloat16_ptr(), + in_dim, + out_dim, + batch_size, + m->aggr, + output.domain.get_volume(), + stream); } else { assert(false && "Unsupported DataType in Embedding"); } @@ -104,6 +114,16 @@ void forward_kernel_wrapper(EmbeddingMeta const *m, m->aggr, output.domain.get_volume(), stream); + } else if (weight.data_type == DT_BF16) { + Internal::forward_kernel(input.get_int64_ptr(), + output.get_bfloat16_ptr(), + weight.get_bfloat16_ptr(), + in_dim, + out_dim, + batch_size, + m->aggr, + output.domain.get_volume(), + stream); } else { assert(false && "Unsupported DataType in Embedding"); } @@ -161,6 +181,16 @@ void backward_kernel_wrapper(EmbeddingMeta const *m, m->aggr, 
output.domain.get_volume(), stream); + } else if (m->output_type[0] == DT_BF16) { + Internal::backward_kernel(input.get_int32_ptr(), + output.get_bfloat16_ptr(), + weight_grad.get_bfloat16_ptr(), + in_dim, + out_dim, + batch_size, + m->aggr, + output.domain.get_volume(), + stream); } else { assert(false && "Unsupported DataType in Embedding"); } @@ -195,6 +225,16 @@ void backward_kernel_wrapper(EmbeddingMeta const *m, m->aggr, output.domain.get_volume(), stream); + } else if (m->output_type[0] == DT_BF16) { + Internal::backward_kernel(input.get_int64_ptr(), + output.get_bfloat16_ptr(), + weight_grad.get_bfloat16_ptr(), + in_dim, + out_dim, + batch_size, + m->aggr, + output.domain.get_volume(), + stream); } else { assert(false && "Unsupported DataType in Embedding"); } @@ -322,6 +362,50 @@ __global__ void embed_backward_no_aggr(int64_t const *input, } } +template <> +__global__ void + embed_backward_no_aggr(int const *input, + __nv_bfloat16 const *output, + __nv_bfloat16 *embed, + int out_dim, + int batch_size) { + CUDA_KERNEL_LOOP(i, batch_size * out_dim) { + int idx = i / out_dim; + int off = i % out_dim; + int wordIdx = input[idx]; +#if __CUDA_ARCH__ >= 700 + atomicAdd(embed + wordIdx * out_dim + off, output[i]); +#else + assert(false); + // TODO: this implementation may result in race condition + // so we use an assertion failure to warn users + embed[wordIdx * out_dim + off] += output[i]; +#endif + } +} + +template <> +__global__ void + embed_backward_no_aggr(int64_t const *input, + __nv_bfloat16 const *output, + __nv_bfloat16 *embed, + int out_dim, + int batch_size) { + CUDA_KERNEL_LOOP(i, batch_size * out_dim) { + int idx = i / out_dim; + int off = i % out_dim; + int64_t wordIdx = input[idx]; +#if __CUDA_ARCH__ >= 700 + atomicAdd(embed + wordIdx * out_dim + off, output[i]); +#else + assert(false); + // TODO: this implementation may result in race condition + // so we use an assertion failure to warn users + embed[wordIdx * out_dim + off] += output[i]; +#endif + } +} + template __global__ void embed_backward_with_aggr(TI const *input, TD const *output, @@ -416,6 +500,74 @@ __global__ void embed_backward_with_aggr(int64_t const *input, } } +template <> +__global__ void + embed_backward_with_aggr(int const *input, + __nv_bfloat16 const *output, + __nv_bfloat16 *embed, + int out_dim, + int in_dim, + int batch_size, + AggrMode aggr) { + __nv_bfloat16 scale = 1.0f / in_dim; + CUDA_KERNEL_LOOP(i, batch_size * out_dim) { + int idx = i / out_dim; + int off = i % out_dim; + __nv_bfloat16 gradient; + if (aggr == AGGR_MODE_SUM) { + gradient = output[i]; + } else { + assert(aggr == AGGR_MODE_AVG); + gradient = output[i] * scale; + } + for (int j = 0; j < in_dim; j++) { + int wordIdx = input[idx * in_dim + j]; +#if __CUDA_ARCH__ >= 700 + atomicAdd(embed + wordIdx * out_dim + off, gradient); +#else + assert(false); + // TODO: this implementation may result in race condition + // so we use an assertion failure to warn users + embed[wordIdx * out_dim + off] += gradient; +#endif + } + } +} + +template <> +__global__ void embed_backward_with_aggr( + int64_t const *input, + __nv_bfloat16 const *output, + __nv_bfloat16 *embed, + int out_dim, + int in_dim, + int batch_size, + AggrMode aggr) { + __nv_bfloat16 scale = 1.0f / in_dim; + CUDA_KERNEL_LOOP(i, batch_size * out_dim) { + int idx = i / out_dim; + int off = i % out_dim; + __nv_bfloat16 gradient; + if (aggr == AGGR_MODE_SUM) { + gradient = output[i]; + } else { + assert(aggr == AGGR_MODE_AVG); + gradient = output[i] * scale; + } + for (int j = 0; j < 
in_dim; j++) { + int64_t wordIdx = input[idx * in_dim + j]; +#if __CUDA_ARCH__ >= 700 + atomicAdd(embed + wordIdx * out_dim + off, gradient); +#else + assert(false); + // TODO: this implementation may result in race condition + // so we use an assertion failure to warn users + embed[wordIdx * out_dim + off] += gradient; +#endif + } + } +} + /*static*/ template void forward_kernel(TI const *input_ptr, diff --git a/src/ops/kernels/linear_kernels.cpp b/src/ops/kernels/linear_kernels.cpp index 072eb5e96b..1004fe2767 100644 --- a/src/ops/kernels/linear_kernels.cpp +++ b/src/ops/kernels/linear_kernels.cpp @@ -124,6 +124,16 @@ void forward_kernel_wrapper(LinearMeta const *m, out_dim, batch_size, stream); + } else if (m->input_type[0] == DT_BF16) { + Internal::forward_kernel(m, + input_ptr, + output_ptr, + weight_ptr, + bias_ptr, + in_dim, + out_dim, + batch_size, + stream); } if (m->profiling) { @@ -189,6 +199,19 @@ void backward_kernel_wrapper(LinearMeta const *m, out_dim, batch_size, stream); + } else if (m->input_type[0] == DT_BF16) { + Internal::backward_kernel(m, + input_ptr, + input_grad_ptr, + output_ptr, + output_grad_ptr, + kernel_ptr, + kernel_grad_ptr, + bias_grad_ptr, + in_dim, + out_dim, + batch_size, + stream); } if (m->profiling) { diff --git a/src/ops/kernels/linear_kernels.cu b/src/ops/kernels/linear_kernels.cu index c30c9f71c1..77ed3b9a9e 100644 --- a/src/ops/kernels/linear_kernels.cu +++ b/src/ops/kernels/linear_kernels.cu @@ -58,6 +58,12 @@ LinearMeta::LinearMeta(FFHandler handler, min(CUDA_NUM_THREADS, parallelism), 0, stream>>>((half *)one_ptr, batch_size); + } else if (data_type == DT_BF16) { + Kernels::Linear::Internal:: + build_one_ptr<<>>((__nv_bfloat16 *)one_ptr, batch_size); } // Allocate descriptors @@ -152,6 +158,16 @@ void forward_kernel_wrapper(LinearMeta const *m, out_dim, batch_size, stream); + } else if (m->input_type[0] == DT_BF16) { + Internal::forward_kernel<__nv_bfloat16>(m, + input_ptr, + output_ptr, + weight_ptr, + bias_ptr, + in_dim, + out_dim, + batch_size, + stream); } if (m->profiling) { @@ -216,6 +232,19 @@ void backward_kernel_wrapper(LinearMeta const *m, out_dim, batch_size, stream); + } else if (m->input_type[0] == DT_BF16) { + Internal::backward_kernel<__nv_bfloat16>(m, + input_ptr, + input_grad_ptr, + output_ptr, + output_grad_ptr, + kernel_ptr, + kernel_grad_ptr, + bias_grad_ptr, + in_dim, + out_dim, + batch_size, + stream); } if (m->profiling) { @@ -316,7 +345,8 @@ void forward_kernel(LinearMeta const *m, } checkCUDA(cublasSetStream(m->handle.blas, stream)); checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); - DT alpha = 1.0f, beta = 0.0f; + typename cublasAlphaBetaType
<DT>::type alpha = 1.0; + typename cublasAlphaBetaType<DT>
::type beta = 0.0; cudaDataType_t input_type = ff_to_cuda_datatype(m->input_type[0]); cudaDataType_t weight_type = m->offload ? ff_to_cuda_datatype(m->weight_ptr_type) @@ -332,6 +362,8 @@ void forward_kernel(LinearMeta const *m, cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; if (m->output_type[0] == DT_FLOAT) { compute_type = CUBLAS_COMPUTE_32F_FAST_16F; + } else if (m->output_type[0] == DT_BF16) { + compute_type = CUBLAS_COMPUTE_32F; } #endif checkCUDA(cublasGemmEx(m->handle.blas, diff --git a/src/ops/kernels/residual_rms_norm_kernels.cu b/src/ops/kernels/residual_rms_norm_kernels.cu index 17ac14449b..46fc43fe53 100644 --- a/src/ops/kernels/residual_rms_norm_kernels.cu +++ b/src/ops/kernels/residual_rms_norm_kernels.cu @@ -132,7 +132,7 @@ __global__ void ResidualRMSNormFusedForwardKernel(int64_t N, using T_ACC = T; for (int64_t j = threadIdx.x; j < N; j += min(blockDim.x, kCUDANumThreads)) { - const int64_t index = i * N + j; + int64_t const index = i * N + j; Y[index] = static_cast(X_out[index]) * static_cast(rms[i]); output[index] = Y[index] * weights[index % N]; } @@ -204,6 +204,14 @@ void forward_kernel_wrapper(ResidualRMSNormMeta const *m, residual_output.get_float_ptr(), output.get_float_ptr(), stream); + } else if (output.data_type == DT_BF16) { + forward_kernel(m, + input1.get_bfloat16_ptr(), + input2.get_bfloat16_ptr(), + weight.get_bfloat16_ptr(), + residual_output.get_bfloat16_ptr(), + output.get_bfloat16_ptr(), + stream); } else { assert(false && "Unsupported data type"); } diff --git a/src/ops/kernels/rms_norm_kernels.cpp b/src/ops/kernels/rms_norm_kernels.cpp index 24ab7051e6..45ff40f0d0 100644 --- a/src/ops/kernels/rms_norm_kernels.cpp +++ b/src/ops/kernels/rms_norm_kernels.cpp @@ -190,6 +190,12 @@ void forward_kernel_wrapper(RMSNormMeta const *m, weight.get_float_ptr(), output.get_float_ptr(), stream); + } else if (output.data_type == DT_BF16) { + forward_kernel(m, + input.get_bfloat16_ptr(), + weight.get_bfloat16_ptr(), + output.get_bfloat16_ptr(), + stream); } else { assert(false && "Unsupported data type"); } diff --git a/src/ops/kernels/rms_norm_kernels.cu b/src/ops/kernels/rms_norm_kernels.cu index 7c9f4a9f98..115f290f14 100644 --- a/src/ops/kernels/rms_norm_kernels.cu +++ b/src/ops/kernels/rms_norm_kernels.cu @@ -139,7 +139,7 @@ __global__ void NormKernel(int64_t N, T const *X, T const *rstd, T *Y) { using T_ACC = T; const int64_t i = blockIdx.x; for (int64_t j = threadIdx.x; j < N; j += blockDim.x) { - const int64_t index = i * N + j; + int64_t const index = i * N + j; Y[index] = static_cast(X[index]) * static_cast(rstd[i]); } } @@ -186,7 +186,7 @@ __global__ void RMSNormFusedForwardKernel(int64_t N, using T_ACC = T; for (int64_t j = threadIdx.x; j < N; j += min(blockDim.x, kCUDANumThreads)) { - const int64_t index = i * N + j; + int64_t const index = i * N + j; Y[index] = static_cast(X[index]) * static_cast(rms[i]); output[index] = Y[index] * weights[index % N]; } @@ -246,6 +246,12 @@ void forward_kernel_wrapper(RMSNormMeta const *m, weight.get_float_ptr(), output.get_float_ptr(), stream); + } else if (output.data_type == DT_BF16) { + forward_kernel(m, + input.get_bfloat16_ptr(), + weight.get_bfloat16_ptr(), + output.get_bfloat16_ptr(), + stream); } else { assert(false && "Unsupported data type"); } diff --git a/src/ops/kernels/softmax.cpp b/src/ops/kernels/softmax.cpp index 89c9f14a01..e785d404d7 100644 --- a/src/ops/kernels/softmax.cpp +++ b/src/ops/kernels/softmax.cpp @@ -107,7 +107,10 @@ template void forward_kernel_wrapper(SoftmaxMeta const *m, 
template void forward_kernel_wrapper(SoftmaxMeta const *m, half const *input_ptr, half *output_ptr); - +template void + forward_kernel_wrapper(SoftmaxMeta const *m, + hip_bfloat16 const *input_ptr, + hip_bfloat16 *output_ptr); template void backward_kernel_wrapper(SoftmaxMeta const *m, float *input_grad_ptr, float const *output_grad_ptr, @@ -116,7 +119,11 @@ template void backward_kernel_wrapper(SoftmaxMeta const *m, half *input_grad_ptr, half const *output_grad_ptr, size_t num_elements); - +template void + backward_kernel_wrapper(SoftmaxMeta const *m, + hip_bfloat16 *input_grad_ptr, + hip_bfloat16 const *output_grad_ptr, + size_t num_elements); namespace Internal { template void forward_kernel(SoftmaxMeta const *m, diff --git a/src/ops/kernels/softmax.cu b/src/ops/kernels/softmax.cu index e47006cc9d..2418f457dd 100644 --- a/src/ops/kernels/softmax.cu +++ b/src/ops/kernels/softmax.cu @@ -105,7 +105,10 @@ template void forward_kernel_wrapper(SoftmaxMeta const *m, template void forward_kernel_wrapper(SoftmaxMeta const *m, half const *input_ptr, half *output_ptr); - +template void + forward_kernel_wrapper<__nv_bfloat16>(SoftmaxMeta const *m, + __nv_bfloat16 const *input_ptr, + __nv_bfloat16 *output_ptr); template void backward_kernel_wrapper(SoftmaxMeta const *m, float *input_grad_ptr, float const *output_grad_ptr, @@ -114,6 +117,12 @@ template void backward_kernel_wrapper(SoftmaxMeta const *m, half *input_grad_ptr, half const *output_grad_ptr, size_t num_elements); + +template void + backward_kernel_wrapper<__nv_bfloat16>(SoftmaxMeta const *m, + __nv_bfloat16 *input_grad_ptr, + __nv_bfloat16 const *output_grad_ptr, + size_t num_elements); namespace Internal { template void forward_kernel(SoftmaxMeta const *m, diff --git a/src/ops/layer_norm.cpp b/src/ops/layer_norm.cpp index 07dbdb3dfb..f675899fcb 100644 --- a/src/ops/layer_norm.cpp +++ b/src/ops/layer_norm.cpp @@ -182,6 +182,15 @@ void LayerNorm::forward_kernel_wrapper(LayerNormMeta const *m, gamma.get_half_ptr(), m->use_bias ? beta.get_half_ptr() : nullptr, stream); + } else if (m->input_type[0] == DT_BF16) { + LayerNorm::forward_kernel( + m, + input.get_bfloat16_ptr(), + output.get_bfloat16_ptr(), + m->elementwise_affine ? gamma.get_bfloat16_ptr() : nullptr, + (m->elementwise_affine && m->use_bias) ? beta.get_bfloat16_ptr() + : nullptr, + stream); } else { assert(false && "unsupport datatype in layernorm"); } diff --git a/src/ops/layer_norm.cu b/src/ops/layer_norm.cu index 44979c48fe..d6a67682e9 100644 --- a/src/ops/layer_norm.cu +++ b/src/ops/layer_norm.cu @@ -273,7 +273,15 @@ void LayerNorm::forward_kernel_wrapper(LayerNormMeta const *m, m->elementwise_affine ? gamma.get_half_ptr() : nullptr, (m->elementwise_affine && m->use_bias) ? beta.get_half_ptr() : nullptr, stream); - } else { + } else if (m->input_type[0] == DT_BF16) { + LayerNorm::forward_kernel<__nv_bfloat16>( + m, + input.get_bfloat16_ptr(), + output.get_bfloat16_ptr(), + m->elementwise_affine ? gamma.get_bfloat16_ptr() : nullptr, + (m->elementwise_affine && m->use_bias) ? 
beta.get_bfloat16_ptr() : nullptr, + stream); + }else { assert(false && "unsupport datatype in layernorm"); } diff --git a/src/ops/linear.cc b/src/ops/linear.cc index 6ca6038778..29e06d3948 100644 --- a/src/ops/linear.cc +++ b/src/ops/linear.cc @@ -33,7 +33,7 @@ using namespace FlexFlow::Kernels::Linear; static constexpr int KERNEL_IDX = 0; static constexpr int BIAS_IDX = 1; -Tensor FFModel::dense(const Tensor input, +Tensor FFModel::dense(Tensor const input, int outDim, ActiMode activation, bool use_bias, @@ -175,7 +175,7 @@ Linear::Linear(FFModel &model, Linear::Linear(FFModel &model, LinearParams const ¶ms, - ParallelTensor const input, + const ParallelTensor input, char const *name, bool allocate_weights) : Linear(model, @@ -194,7 +194,7 @@ Linear::Linear(FFModel &model, Linear::Linear(FFModel &model, LayerID const &_layer_guid, - const ParallelTensor _input, + ParallelTensor const _input, int out_dim, ActiMode _activation, RegularizerMode _kernel_reg_type, @@ -437,6 +437,14 @@ OpMeta *Linear::init_task(Task const *task, return init_task_with_dim( \ task, regions, ctx, runtime); \ } \ + } else if (output.data_type == DT_BF16) { \ + if (linear->quantization_type != DT_NONE) { \ + return init_task_with_dim<__ff_bfloat16, char, DIM>( \ + task, regions, ctx, runtime); \ + } else { \ + return init_task_with_dim<__ff_bfloat16, __ff_bfloat16, DIM>( \ + task, regions, ctx, runtime); \ + } \ } else { \ assert(false && "Unsupported data type"); \ } @@ -704,6 +712,14 @@ void Linear::forward_task(Task const *task, return forward_task_with_dim( \ task, regions, ctx, runtime); \ } \ + } else if (m->output_type[0] == DT_BF16) { \ + if (m->quantization_type != DT_NONE) { \ + return forward_task_with_dim<__ff_bfloat16, char, DIM>( \ + task, regions, ctx, runtime); \ + } else { \ + return forward_task_with_dim<__ff_bfloat16, __ff_bfloat16, DIM>( \ + task, regions, ctx, runtime); \ + } \ } else { \ assert(false && "Unsupported data type"); \ } @@ -859,6 +875,9 @@ void Linear::backward_task(Task const *task, return backward_task_with_dim(task, regions, ctx, runtime); \ } else if (m->output_type[0] == DT_FLOAT) { \ return backward_task_with_dim(task, regions, ctx, runtime); \ + } else if (m->output_type[0] == DT_BF16) { \ + return backward_task_with_dim<__ff_bfloat16, DIM>( \ + task, regions, ctx, runtime); \ } else { \ assert(false && "Unsupported data type"); \ } diff --git a/src/ops/residual_layer_norm.cpp b/src/ops/residual_layer_norm.cpp index f1b7a537b0..290d8cb3d9 100644 --- a/src/ops/residual_layer_norm.cpp +++ b/src/ops/residual_layer_norm.cpp @@ -230,6 +230,18 @@ void ResidualLayerNorm::inference_kernel_wrapper( m->elementwise_affine ? gamma.get_half_ptr() : nullptr, (m->elementwise_affine && m->use_bias) ? beta.get_half_ptr() : nullptr, stream); + } else if (m->input_type[0] == DT_BF16) { + ResidualLayerNorm::inference_kernel( + m, + input.get_bfloat16_ptr(), + residual1.get_bfloat16_ptr(), + m->use_two_residuals ? residual2.get_bfloat16_ptr() : nullptr, + added_output.get_bfloat16_ptr(), + output.get_bfloat16_ptr(), + m->elementwise_affine ? gamma.get_bfloat16_ptr() : nullptr, + (m->elementwise_affine && m->use_bias) ? 
beta.get_bfloat16_ptr() + : nullptr, + stream); } else { assert(false && "unsupport datatype in layernorm"); } diff --git a/src/ops/residual_layer_norm.cu b/src/ops/residual_layer_norm.cu index e5ebdce6ed..2f310c3c2c 100644 --- a/src/ops/residual_layer_norm.cu +++ b/src/ops/residual_layer_norm.cu @@ -225,6 +225,18 @@ void ResidualLayerNorm::inference_kernel_wrapper( m->elementwise_affine ? gamma.get_half_ptr() : nullptr, (m->elementwise_affine && m->use_bias) ? beta.get_half_ptr() : nullptr, stream); + } else if (m->input_type[0] == DT_BF16) { + ResidualLayerNorm::inference_kernel<__nv_bfloat16>( + m, + input.get_bfloat16_ptr(), + residual1.get_bfloat16_ptr(), + m->use_two_residuals ? residual2.get_bfloat16_ptr() : nullptr, + added_output.get_bfloat16_ptr(), + output.get_bfloat16_ptr(), + m->elementwise_affine ? gamma.get_bfloat16_ptr() : nullptr, + (m->elementwise_affine && m->use_bias) ? beta.get_bfloat16_ptr() + : nullptr, + stream); } else { assert(false && "unsupport datatype in layernorm"); } diff --git a/src/ops/sampling.cu b/src/ops/sampling.cu index 461d72ec71..96438fa93a 100644 --- a/src/ops/sampling.cu +++ b/src/ops/sampling.cu @@ -182,6 +182,14 @@ void Sampling::forward_kernel_wrapper(SamplingMeta const *m, length, batch_size, stream); + } else if (input.data_type == DT_BF16) { + Sampling::forward_kernel<__nv_bfloat16>(m, + input.get_bfloat16_ptr(), + indices.get_int32_ptr(), + m->top_p, + length, + batch_size, + stream); } else { assert(false && "Unsupported data type"); } @@ -270,7 +278,22 @@ SamplingMeta::SamplingMeta(FFHandler handler, 0, // begin_bit data_type_size(data_type) * 8, // end_bit = sizeof(KeyT) * 8 stream)); - } else { + } else if (data_type == DT_BF16) { + checkCUDA(cub::DeviceSegmentedRadixSort::SortPairsDescending( + d_temp_storage, + temp_storage_bytes, + input.get_bfloat16_ptr(), + input.get_bfloat16_ptr(), + idx, + idx, + total_ele, + batch_size, + begin_offset, + end_offset + 1, + 0, // begin_bit + data_type_size(data_type) * 8, // end_bit = sizeof(KeyT) * 8 + stream)); + }else { assert(false && "input type in float and half"); } diff --git a/src/ops/sigmoid_silu_multi.cpp b/src/ops/sigmoid_silu_multi.cpp index 7b7f30a288..b45625d204 100644 --- a/src/ops/sigmoid_silu_multi.cpp +++ b/src/ops/sigmoid_silu_multi.cpp @@ -101,6 +101,14 @@ void SigmoidSiluMulti::inference_kernel_wrapper( input1.get_half_ptr(), input2.get_half_ptr(), output.get_half_ptr()); + } else if (m->input_type[0] == DT_BF16) { + SigmoidSiluMultiKernel<<>>(input1.domain.get_volume(), + input1.get_bfloat16_ptr(), + input2.get_bfloat16_ptr(), + output.get_bfloat16_ptr()); } else { assert(false && "unsupport datatype in SigmoidSiluMulti"); } diff --git a/src/ops/sigmoid_silu_multi.cu b/src/ops/sigmoid_silu_multi.cu index 590b641b5a..05bf8b9be8 100644 --- a/src/ops/sigmoid_silu_multi.cu +++ b/src/ops/sigmoid_silu_multi.cu @@ -80,6 +80,14 @@ void SigmoidSiluMulti::inference_kernel_wrapper( input1.get_half_ptr(), input2.get_half_ptr(), output.get_half_ptr()); + } else if (m->input_type[0] == DT_BF16) { + SigmoidSiluMultiKernel<<>>(input1.domain.get_volume(), + input1.get_bfloat16_ptr(), + input2.get_bfloat16_ptr(), + output.get_bfloat16_ptr()); } else { assert(false && "unsupport datatype in SigmoidSiluMulti"); } diff --git a/src/ops/softmax.cc b/src/ops/softmax.cc index ba0a1288d6..0d0cceb5c4 100644 --- a/src/ops/softmax.cc +++ b/src/ops/softmax.cc @@ -319,6 +319,9 @@ void Softmax::forward_task(Task const *task, forward_kernel_wrapper(m, input.get_half_ptr(), output.get_half_ptr()); } else if 
(m->output_type == DT_FLOAT) { forward_kernel_wrapper(m, input.get_float_ptr(), output.get_float_ptr()); + } else if (m->output_type == DT_BF16) { + forward_kernel_wrapper( + m, input.get_bfloat16_ptr(), output.get_bfloat16_ptr()); } else { assert(false && "Unsupported data type"); } @@ -366,6 +369,9 @@ void Softmax::backward_task(Task const *task, return backward_task_with_dim(task, regions, ctx, runtime); \ } else if (m->output_type == DT_FLOAT) { \ return backward_task_with_dim(task, regions, ctx, runtime); \ + } else if (m->output_type == DT_BF16) { \ + return backward_task_with_dim<__ff_bfloat16, DIM>( \ + task, regions, ctx, runtime); \ } else { \ assert(false && "Unsupported data type"); \ } @@ -429,6 +435,9 @@ void Softmax::inference_task(Task const *task, forward_kernel_wrapper(m, input.get_half_ptr(), output.get_half_ptr()); } else if (m->output_type == DT_FLOAT) { forward_kernel_wrapper(m, input.get_float_ptr(), output.get_float_ptr()); + } else if (m->output_type == DT_BF16) { + forward_kernel_wrapper( + m, input.get_bfloat16_ptr(), output.get_bfloat16_ptr()); } else { assert(false && "Unsupported data type"); } diff --git a/src/ops/spec_inc_multihead_self_attention.cpp b/src/ops/spec_inc_multihead_self_attention.cpp index b1687d12a2..ca0d607ecd 100644 --- a/src/ops/spec_inc_multihead_self_attention.cpp +++ b/src/ops/spec_inc_multihead_self_attention.cpp @@ -562,6 +562,19 @@ void SpecIncMultiHeadSelfAttention::inference_kernel_wrapper( output.get_float_ptr(), bias_ptr, stream); + } else if (input.data_type == DT_BF16) { + float const *bias_ptr = use_bias + ? bias.get_bfloat16_ptr() + : static_cast(nullptr); + Kernels::SpecIncMultiHeadAttention::inference_kernel( + m, + bc, + shard_id, + input.get_bfloat16_ptr(), + weight.get_bfloat16_ptr(), + output.get_bfloat16_ptr(), + bias_ptr, + stream); } else { assert(false && "Unspported data type"); } diff --git a/src/ops/spec_inc_multihead_self_attention.cu b/src/ops/spec_inc_multihead_self_attention.cu index 2d80ed2221..0d3e3a1250 100644 --- a/src/ops/spec_inc_multihead_self_attention.cu +++ b/src/ops/spec_inc_multihead_self_attention.cu @@ -480,6 +480,8 @@ void compute_attention_kernel_prompt(SpecIncMultiHeadSelfAttentionMeta const *m, cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; if (m->output_type[0] == DT_FLOAT) { compute_type = CUBLAS_COMPUTE_32F_FAST_16F; + } else if (m->output_type[0] == DT_BF16) { + compute_type = CUBLAS_COMPUTE_32F; } #endif // int num_requests = bc->num_active_requests(); @@ -533,7 +535,8 @@ void compute_attention_kernel_prompt(SpecIncMultiHeadSelfAttentionMeta const *m, int strideC = num_new_tokens * total_tokens; // a flag of using this scaling alpha - DT alpha = 1.0f, beta = 0.0f; + typename cublasAlphaBetaType
<DT>::type alpha = 1.0; + typename cublasAlphaBetaType
<DT>::type beta = 0.0; if (*m->qk_prod_scaling) { alpha = static_cast<DT>
(1.0f / sqrt(m->kProjSize)); } @@ -789,6 +792,19 @@ void SpecIncMultiHeadSelfAttention::inference_kernel_wrapper( output.get_float_ptr(), bias_ptr, stream); + } else if (input.data_type == DT_BF16) { + __nv_bfloat16 const *bias_ptr = + use_bias ? bias.get_bfloat16_ptr() + : static_cast<__nv_bfloat16 const *>(nullptr); + Kernels::SpecIncMultiHeadSelfAttention::inference_kernel( + m, + bc, + shard_id, + input.get_bfloat16_ptr(), + weight.get_bfloat16_ptr(), + output.get_bfloat16_ptr(), + bias_ptr, + stream); } else { assert(false && "Unspported data type"); }
diff --git a/src/ops/tree_inc_multihead_self_attention.cpp b/src/ops/tree_inc_multihead_self_attention.cpp index 26291fb3b4..754ae24793 100644 --- a/src/ops/tree_inc_multihead_self_attention.cpp +++ b/src/ops/tree_inc_multihead_self_attention.cpp @@ -570,6 +570,24 @@ void TreeIncMultiHeadSelfAttention::inference_kernel_wrapper( output.get_float_ptr(), bias_ptr, stream); + } else if (input.data_type == DT_BF16) { + if (m->offload) { + pre_build_weight_kernel<hip_bfloat16>(m, weight, input.data_type, stream); + } + + hip_bfloat16 const *bias_ptr = + use_bias ? bias.get_bfloat16_ptr() + : static_cast<hip_bfloat16 const *>(nullptr); + Kernels::TreeIncMultiHeadAttention::inference_kernel( + m, + bc, + shard_id, + input.get_bfloat16_ptr(), + m->offload ? static_cast<hip_bfloat16 *>(m->weight_ptr) + : weight.get_bfloat16_ptr(), + output.get_bfloat16_ptr(), + bias_ptr, + stream); } else { assert(false && "Unspported data type"); }
diff --git a/src/ops/tree_inc_multihead_self_attention.cu b/src/ops/tree_inc_multihead_self_attention.cu index fc86e1498e..2b50e93a02 100644 --- a/src/ops/tree_inc_multihead_self_attention.cu +++ b/src/ops/tree_inc_multihead_self_attention.cu @@ -1003,6 +1003,25 @@ void TreeIncMultiHeadSelfAttention::inference_kernel_wrapper( output.get_float_ptr(), bias_ptr, stream); + } else if (input.data_type == DT_BF16) { + if (m->offload) { + pre_build_weight_kernel<__nv_bfloat16>( + m, weight, input.data_type, stream); + } + + __nv_bfloat16 const *bias_ptr = + use_bias ? bias.get_bfloat16_ptr() + : static_cast<__nv_bfloat16 const *>(nullptr); + Kernels::TreeIncMultiHeadAttention::inference_kernel( + m, + bc, + shard_id, + input.get_bfloat16_ptr(), + m->offload ?
static_cast<__nv_bfloat16 *>(m->weight_ptr) + : weight.get_bfloat16_ptr(), + output.get_bfloat16_ptr(), + bias_ptr, + stream); } else { assert(false && "Unspported data type"); } diff --git a/src/parallel_ops/combine.cc b/src/parallel_ops/combine.cc index 7c266c5392..ddbafc23c9 100644 --- a/src/parallel_ops/combine.cc +++ b/src/parallel_ops/combine.cc @@ -365,6 +365,8 @@ void Combine::forward_task(Task const *task, DataType data_type = m->input_type[0]; if (data_type == DT_HALF) { forward_task_with_type(task, regions, ctx, runtime); + } else if (data_type == DT_BF16) { + forward_task_with_type<__ff_bfloat16>(task, regions, ctx, runtime); } else if (data_type == DT_FLOAT) { forward_task_with_type(task, regions, ctx, runtime); } else if (data_type == DT_DOUBLE) { diff --git a/src/parallel_ops/kernels/combine_kernels.cpp b/src/parallel_ops/kernels/combine_kernels.cpp index d6e9568223..b593a92a7d 100644 --- a/src/parallel_ops/kernels/combine_kernels.cpp +++ b/src/parallel_ops/kernels/combine_kernels.cpp @@ -57,6 +57,9 @@ template void forward_kernel(half const *input_ptr, template void forward_kernel(float const *input_ptr, float *output_ptr, size_t num_elements); +template void forward_kernel(hip_bfloat16 const *input_ptr, + hip_bfloat16 *output_ptr, + size_t num_elements); template void forward_kernel(double const *input_ptr, double *output_ptr, size_t num_elements); diff --git a/src/parallel_ops/kernels/combine_kernels.cu b/src/parallel_ops/kernels/combine_kernels.cu index 1ab79a7944..ca9451ed64 100644 --- a/src/parallel_ops/kernels/combine_kernels.cu +++ b/src/parallel_ops/kernels/combine_kernels.cu @@ -50,6 +50,11 @@ template void forward_kernel(half const *input_ptr, template void forward_kernel(float const *input_ptr, float *output_ptr, size_t num_elements); + +template void forward_kernel<__nv_bfloat16>(__nv_bfloat16 const *input_ptr, + __nv_bfloat16 *output_ptr, + size_t num_elements); + template void forward_kernel(double const *input_ptr, double *output_ptr, size_t num_elements); diff --git a/src/parallel_ops/kernels/reduction_kernels.cpp b/src/parallel_ops/kernels/reduction_kernels.cpp index 2a3fe5cca1..ade5d9b402 100644 --- a/src/parallel_ops/kernels/reduction_kernels.cpp +++ b/src/parallel_ops/kernels/reduction_kernels.cpp @@ -78,6 +78,12 @@ template __global__ void reduction_forward_kernel(half const *input_ptr, half *output_ptr, size_t num_elements, size_t num_replicas); +template __global__ void + reduction_forward_kernel(hip_bfloat16 const *input_ptr, + hip_bfloat16 *output_ptr, + size_t num_elements, + size_t num_replicas); + template void forward_kernel(float const *input_ptr, float *output_ptr, size_t num_elements, @@ -86,6 +92,10 @@ template void forward_kernel(half const *input_ptr, half *output_ptr, size_t num_elements, size_t num_replicas); +template void forward_kernel(hip_bfloat16 const *input_ptr, + hip_bfloat16 *output_ptr, + size_t num_elements, + size_t num_replicas); template void backward_kernel(float const *output_grad_ptr, float *input_grad_ptr, size_t num_elements); diff --git a/src/parallel_ops/kernels/reduction_kernels.cu b/src/parallel_ops/kernels/reduction_kernels.cu index 34ae8007da..67547c6d17 100644 --- a/src/parallel_ops/kernels/reduction_kernels.cu +++ b/src/parallel_ops/kernels/reduction_kernels.cu @@ -71,6 +71,13 @@ template __global__ void reduction_forward_kernel(half const *input_ptr, half *output_ptr, size_t num_elements, size_t num_replicas); + +template __global__ void + reduction_forward_kernel<__nv_bfloat16>(__nv_bfloat16 const *input_ptr, 
+ __nv_bfloat16 *output_ptr, + size_t num_elements, + size_t num_replicas); + template void forward_kernel(float const *input_ptr, float *output_ptr, size_t num_elements, @@ -79,6 +86,12 @@ template void forward_kernel(half const *input_ptr, half *output_ptr, size_t num_elements, size_t num_replicas); + +template void forward_kernel<__nv_bfloat16>(__nv_bfloat16 const *input_ptr, + __nv_bfloat16 *output_ptr, + size_t num_elements, + size_t num_replicas); + template void backward_kernel(float const *output_grad_ptr, float *input_grad_ptr, size_t num_elements); diff --git a/src/parallel_ops/kernels/replicate_kernels.cpp b/src/parallel_ops/kernels/replicate_kernels.cpp index 1647f014be..7d6b2fc63a 100644 --- a/src/parallel_ops/kernels/replicate_kernels.cpp +++ b/src/parallel_ops/kernels/replicate_kernels.cpp @@ -73,6 +73,9 @@ template void forward_kernel(float const *input_ptr, template void forward_kernel(half const *input_ptr, half *output_ptr, size_t num_elements); +template void forward_kernel(hip_bfloat16 const *input_ptr, + hip_bfloat16 *output_ptr, + size_t num_elements); template __global__ void replicate_backward_kernel(float const *input_ptr, float *output_ptr, diff --git a/src/parallel_ops/kernels/replicate_kernels.cu b/src/parallel_ops/kernels/replicate_kernels.cu index 35bc109bd3..573de9ae2e 100644 --- a/src/parallel_ops/kernels/replicate_kernels.cu +++ b/src/parallel_ops/kernels/replicate_kernels.cu @@ -66,6 +66,9 @@ template void forward_kernel(float const *input_ptr, template void forward_kernel(half const *input_ptr, half *output_ptr, size_t num_elements); +template void forward_kernel<__nv_bfloat16>(__nv_bfloat16 const *input_ptr, + __nv_bfloat16 *output_ptr, + size_t num_elements); template __global__ void replicate_backward_kernel(float const *input_ptr, float *output_ptr, diff --git a/src/parallel_ops/reduction.cc b/src/parallel_ops/reduction.cc index 5dca591328..dd2d4a54c2 100644 --- a/src/parallel_ops/reduction.cc +++ b/src/parallel_ops/reduction.cc @@ -380,6 +380,11 @@ void Reduction::forward_task(Task const *task, output.get_float_ptr(), num_elements, num_replicas); + } else if (input.data_type == DT_BF16) { + forward_kernel<__ff_bfloat16>(input.get_bfloat16_ptr(), + output.get_bfloat16_ptr(), + num_elements, + num_replicas); } else { assert(false && "Unspported data type"); } diff --git a/src/parallel_ops/replicate.cc b/src/parallel_ops/replicate.cc index 20face74e8..c3821a7d74 100644 --- a/src/parallel_ops/replicate.cc +++ b/src/parallel_ops/replicate.cc @@ -373,6 +373,10 @@ void Replicate::forward_task(Task const *task, forward_kernel(input.get_float_ptr(), output.get_float_ptr(), input_domain.get_volume()); + } else if (input.data_type == DT_BF16) { + forward_kernel<__ff_bfloat16>(input.get_bfloat16_ptr(), + output.get_bfloat16_ptr(), + input_domain.get_volume()); } else { assert(false && "Unspported data type"); } diff --git a/src/runtime/accessor.cc b/src/runtime/accessor.cc index d3b94bf14a..a4bd0c11a2 100644 --- a/src/runtime/accessor.cc +++ b/src/runtime/accessor.cc @@ -77,6 +77,15 @@ half const *GenericTensorAccessorR::get_half_ptr() const { } } +__ff_bfloat16 const *GenericTensorAccessorR::get_bfloat16_ptr() const { + if (data_type == DT_BF16) { + return static_cast<__ff_bfloat16 const *>(ptr); + } else { + assert(false && "Invalid Accessor Type"); + return static_cast<__ff_bfloat16 const *>(nullptr); + } +} + char const *GenericTensorAccessorR::get_byte_ptr() const { if (data_type == DT_INT4 || data_type == DT_INT8) { return static_cast(ptr); @@ -165,6 
+174,15 @@ half *GenericTensorAccessorW::get_half_ptr() const { } } +__ff_bfloat16 *GenericTensorAccessorW::get_bfloat16_ptr() const { + if (data_type == DT_BF16) { + return static_cast<__ff_bfloat16 *>(ptr); + } else { + assert(false && "Invalid Accessor Type"); + return static_cast<__ff_bfloat16 *>(nullptr); + } +} + char *GenericTensorAccessorW::get_byte_ptr() const { if (data_type == DT_INT4 || data_type == DT_INT8) { return static_cast(ptr); @@ -271,6 +289,11 @@ GenericTensorAccessorR ptr = helperGetTensorPointerRO(region, req, fid, ctx, runtime); break; } + case DT_BF16: { + ptr = helperGetTensorPointerRO<__ff_bfloat16>( + region, req, fid, ctx, runtime); + break; + } case DT_FLOAT: { ptr = helperGetTensorPointerRO(region, req, fid, ctx, runtime); break; @@ -317,6 +340,11 @@ GenericTensorAccessorW ptr = helperGetTensorPointerWO(region, req, fid, ctx, runtime); break; } + case DT_BF16: { + ptr = helperGetTensorPointerWO<__ff_bfloat16>( + region, req, fid, ctx, runtime); + break; + } case DT_FLOAT: { ptr = helperGetTensorPointerWO(region, req, fid, ctx, runtime); break; @@ -363,6 +391,11 @@ GenericTensorAccessorW ptr = helperGetTensorPointerRW(region, req, fid, ctx, runtime); break; } + case DT_BF16: { + ptr = helperGetTensorPointerRW<__ff_bfloat16>( + region, req, fid, ctx, runtime); + break; + } case DT_FLOAT: { ptr = helperGetTensorPointerRW(region, req, fid, ctx, runtime); break; diff --git a/src/runtime/cuda_helper.cu b/src/runtime/cuda_helper.cu index fa6bf55fe5..46ab6b42da 100644 --- a/src/runtime/cuda_helper.cu +++ b/src/runtime/cuda_helper.cu @@ -320,6 +320,33 @@ __host__ void checkCUDA(cudaFreeHost(host_ptr)); } +template <> +__host__ void save_tensor(__nv_bfloat16 const *ptr, + size_t num_elements, + char const *file_name) { + cudaStream_t stream; + checkCUDA(get_legion_stream(&stream)); + __nv_bfloat16 *host_ptr; + checkCUDA(cudaHostAlloc(&host_ptr, + sizeof(__nv_bfloat16) * num_elements, + cudaHostAllocPortable | cudaHostAllocMapped)); + checkCUDA(cudaMemcpyAsync(host_ptr, + ptr, + sizeof(__nv_bfloat16) * num_elements, + cudaMemcpyDeviceToHost, + stream)); + checkCUDA(cudaDeviceSynchronize()); + FILE *tensor_file; + tensor_file = fopen(file_name, "w"); + assert(tensor_file != NULL); + for (unsigned i = 0; i < num_elements; i++) { + fprintf(tensor_file, "%.9f, ", (float)host_ptr[i]); + } + + fclose(tensor_file); + checkCUDA(cudaFreeHost(host_ptr)); +} + template <> __host__ void save_tensor(int32_t const *ptr, size_t num_elements, @@ -519,6 +546,8 @@ cudnnDataType_t ff_to_cudnn_datatype(DataType type) { switch (type) { case DT_HALF: return CUDNN_DATA_HALF; + case DT_BF16: + return CUDNN_DATA_BFLOAT16; case DT_FLOAT: return CUDNN_DATA_FLOAT; case DT_DOUBLE: @@ -535,6 +564,8 @@ cudaDataType_t ff_to_cuda_datatype(DataType type) { switch (type) { case DT_HALF: return CUDA_R_16F; + case DT_BF16: + return CUDA_R_16BF; case DT_FLOAT: return CUDA_R_32F; case DT_DOUBLE: @@ -552,6 +583,8 @@ ncclDataType_t ff_to_nccl_datatype(DataType type) { switch (type) { case DT_HALF: return ncclHalf; + case DT_BF16: + return ncclBfloat16; case DT_FLOAT: return ncclFloat; case DT_DOUBLE: @@ -595,6 +628,9 @@ cudnnDataType_t cuda_to_cudnn_datatype(cudaDataType_t type) { template __global__ void assign_kernel(half *ptr, coord_t size, half value); +template __global__ void assign_kernel<__nv_bfloat16>(__nv_bfloat16 *ptr, + coord_t size, + __nv_bfloat16 value); template __global__ void assign_kernel(float *ptr, coord_t size, float value); template __global__ void @@ -700,6 +736,10 @@ template __host__ 
void save_tensor(int64_t const *ptr, template __host__ void save_tensor(half const *ptr, size_t rect, char const *file_name); +template __host__ void save_tensor<__nv_bfloat16>(__nv_bfloat16 const *ptr, + size_t rect, + char const *file_name); + template __host__ float *download_tensor(float const *ptr, size_t num_elements); template __host__ half *download_tensor(half const *ptr, diff --git a/src/runtime/ffconst_utils.cc b/src/runtime/ffconst_utils.cc index c7b6e1257a..8c130fcfb1 100644 --- a/src/runtime/ffconst_utils.cc +++ b/src/runtime/ffconst_utils.cc @@ -213,6 +213,8 @@ size_t data_type_size(DataType type) { switch (type) { case DT_HALF: return sizeof(half); + case DT_BF16: + return sizeof(__ff_bfloat16); case DT_FLOAT: return sizeof(float); case DT_DOUBLE: diff --git a/src/runtime/file_loader.cc b/src/runtime/file_loader.cc index 56558b3185..6f5864acff 100644 --- a/src/runtime/file_loader.cc +++ b/src/runtime/file_loader.cc @@ -127,6 +127,335 @@ void load_attention_weights_multi_query(DT *ptr, } } +///////////////////////bfloat16 function/////////////////////// + +void load_from_file_b16(__ff_bfloat16 *ptr, size_t size, std::string filepath) { + std::ifstream in(filepath, std::ios::in | std::ios::binary); + if (!in.good()) { + std::cout << "Could not open file: " << filepath << std::endl; + } + assert(in.good() && "incorrect weight file path"); + std::vector host_array(size); + size_t loaded_data_size = sizeof(float) * size; + in.seekg(0, in.end); + in.seekg(0, in.beg); + in.read((char *)host_array.data(), loaded_data_size); + + size_t in_get_size = in.gcount(); + if (in_get_size != loaded_data_size) { + std::cout << "load weight data error " << in_get_size << ", " + << loaded_data_size << ", " << sizeof(float) + << ", file path: " << filepath << std::endl; + assert(false); + } + assert(size == host_array.size()); + + // normal + long data_index = 0; + for (auto v : host_array) { + ptr[data_index++] = __float2bfloat16(v); + } + in.close(); +} + +void load_attention_weights_v2_b16(__ff_bfloat16 *ptr, + int num_heads, + int num_kv_heads, + size_t hidden_dim, + size_t qkv_inner_dim, + std::string layer_name, + std::string weights_folder, + size_t volume, + int tensor_parallelism_degree) { + // layers_0_attention_wq_weight + // layers_0_self_attn_q_proj_weight + std::string q_file = layer_name + "_wq_weight"; + std::string k_file = layer_name + "_wk_weight"; + std::string v_file = layer_name + "_wv_weight"; + std::string o_file = layer_name + "_wo_weight"; + std::vector weight_filenames = {q_file, k_file, v_file}; + int file_index = 0; + + int base_index = 0; + size_t single_proj_size = + hidden_dim * + qkv_inner_dim; // size of each of Q,K,V,O weights for a single head + size_t one_weight_file_size = + num_heads * single_proj_size; // size of each of Q/K/V/O for all heads + + size_t q_size = one_weight_file_size, o_size = one_weight_file_size; + size_t k_size = single_proj_size * num_kv_heads, + v_size = single_proj_size * num_kv_heads; + + size_t k_replicate_size = one_weight_file_size; + size_t v_replicate_size = one_weight_file_size; + + int replicate_num = num_heads / num_kv_heads; + + // stride for q, k, v, o + size_t stride_size = (q_size + v_replicate_size + k_replicate_size + o_size) / + tensor_parallelism_degree; + for (auto filename : weight_filenames) { + std::cout << "Loading weight file " << filename << std::endl; + std::string weight_filepath = join_path({weights_folder, filename}); + + int data_index = 0; + size_t partial_size = (file_index == 0 || file_index == 3) + ? 
one_weight_file_size + : single_proj_size * num_kv_heads; + size_t one_partition_size = + one_weight_file_size / tensor_parallelism_degree; + + std::ifstream in(weight_filepath, std::ios::in | std::ios::binary); + if (!in.good()) { + std::cout << "Could not open file: " << weight_filepath << std::endl; + } + assert(in.good() && "incorrect weight file path"); + std::vector host_array(partial_size); + size_t loaded_data_size = sizeof(float) * partial_size; + in.seekg(0, in.end); + in.seekg(0, in.beg); + in.read((char *)host_array.data(), loaded_data_size); + size_t in_get_size = in.gcount(); + + if (in_get_size != loaded_data_size) { + std::cout << "load attention data error " << in_get_size << ", " + << loaded_data_size << ", " << file_index << ", " + << weight_filepath << "\n"; + assert(false && "data size mismatch"); + } + // wq, wk, wo + if (file_index == 0) { + for (int i = 0; i < tensor_parallelism_degree; i++) { + for (int j = 0; j < one_partition_size; j++) { + ptr[base_index + i * stride_size + j] = + __float2bfloat16(host_array.at(data_index++)); + } + } + } else { + for (int i = 0; i < num_heads; i++) { + int kv_idx = i / (num_heads / num_kv_heads); + int head_idx = i % (num_heads / tensor_parallelism_degree); + int tp_idx = (i / (num_heads / tensor_parallelism_degree)); + for (int j = 0; j < single_proj_size; j++) { + ptr[base_index + tp_idx * stride_size + single_proj_size * head_idx + + j] = + __float2bfloat16(host_array.at(kv_idx * single_proj_size + j)); + } + } + } + + // assert(data_index == partial_size); + base_index += one_partition_size; + file_index++; + } + assert(base_index == (q_size + k_replicate_size + v_replicate_size) / + tensor_parallelism_degree); + + { + std::cout << "Loading weight file " << o_file << std::endl; + std::string weight_filepath = join_path({weights_folder, o_file}); + + std::ifstream in(weight_filepath, std::ios::in | std::ios::binary); + if (!in.good()) { + std::cout << "Could not open file: " << weight_filepath << std::endl; + } + assert(in.good() && "incorrect weight file path"); + std::vector host_array(one_weight_file_size); + size_t loaded_data_size = sizeof(float) * one_weight_file_size; + in.seekg(0, in.end); + in.seekg(0, in.beg); + in.read((char *)host_array.data(), loaded_data_size); + size_t in_get_size = in.gcount(); + + if (in_get_size != loaded_data_size) { + std::cout << "load data error" << std::endl; + assert(false); + } + assert(one_weight_file_size == host_array.size()); + int data_index = 0; + + int one_partition_size = + qkv_inner_dim * (num_heads / tensor_parallelism_degree); + for (int i = 0; i < one_weight_file_size; i++) { + int part_idx = (i / one_partition_size) % tensor_parallelism_degree; + int block_num = (i / one_partition_size); + int offset = block_num / tensor_parallelism_degree * one_partition_size + + (i % one_partition_size); + ptr[base_index + part_idx * stride_size + offset] = + __float2bfloat16(host_array.at(data_index++)); + } + + in.close(); + + assert(data_index == one_weight_file_size); + } +} + +void load_attention_bias_v2_b16(__ff_bfloat16 *ptr, + int num_heads, + int num_kv_heads, + size_t hidden_dim, + size_t qkv_inner_dim, + bool final_bias, + std::string layer_name, + std::string weights_folder) { + std::string q_file = layer_name + "_wq_bias"; + std::string k_file = layer_name + "_wk_bias"; + std::string v_file = layer_name + "_wv_bias"; + std::vector bias_files = {q_file, k_file, v_file}; + if (final_bias) { + std::string o_file = layer_name + "_wo_bias"; + bias_files.push_back(o_file); + } + 
+ int file_index = 0; + + // now only opt use this. + // assert(num_heads == num_kv_heads); + int idx = 0; + + for (auto filename : bias_files) { + std::cout << "Loading weight file " << filename << std::endl; + std::string weight_filepath = join_path({weights_folder, filename}); + + int n_heads = file_index == 0 ? num_heads : num_kv_heads; + + int replicate_num = num_heads / num_kv_heads; + + size_t qkv_partial_size = qkv_inner_dim * n_heads; + size_t qkv_replicate_size = qkv_inner_dim * num_heads; + size_t out_partial_size = hidden_dim; + size_t partial_size = + (file_index < 3) ? qkv_partial_size : out_partial_size; + std::ifstream in(weight_filepath, std::ios::in | std::ios::binary); + assert(in.good() && "incorrect bias file path"); + std::vector host_array(partial_size); + size_t loaded_data_size = sizeof(float) * partial_size; + in.seekg(0, in.end); + in.seekg(0, in.beg); + in.read((char *)host_array.data(), loaded_data_size); + size_t in_get_size = in.gcount(); + + if (in_get_size != loaded_data_size) { + printf( + "load bias data error: in_get_size (%lu) != loaded_data_size (%lu)\n", + in_get_size, + loaded_data_size); + assert(false); + } + assert(partial_size == host_array.size()); + + size_t data_index = 0; + + // q, o + if (file_index == 0 || file_index == 3) { + for (int i = 0; i < partial_size; i++) { + ptr[idx + i] = __float2bfloat16(host_array.at(data_index)); + data_index++; + } + } else { + // k, v + for (int i = 0; i < partial_size; i++) { + for (int j = 0; j < replicate_num; j++) { + ptr[idx + j * partial_size + i] = + __float2bfloat16(host_array.at(data_index)); + } + data_index++; + } + } + + file_index++; + idx += qkv_replicate_size; + + in.close(); + } +} + +void FileDataLoader::load_single_weight_tensor_b16(FFModel *ff, + Layer *l, + int weight_idx) { + Tensor weight = l->weights[weight_idx]; + + // Create a buffer to store weight data from the file + size_t volume = 1; + std::vector dims_vec; + for (int i = 0; i < weight->num_dims; i++) { + dims_vec.push_back(weight->dims[i]); + volume *= weight->dims[i]; + } + assert(data_type_size(weight->data_type) == sizeof(__ff_bfloat16)); + __ff_bfloat16 *data = (__ff_bfloat16 *)malloc(sizeof(__ff_bfloat16) * volume); + + std::string weight_filename = removeGuidOperatorName(std::string(l->name)); + + if (l->op_type == OP_INC_MULTIHEAD_SELF_ATTENTION || + l->op_type == OP_SPEC_INC_MULTIHEAD_SELF_ATTENTION || + l->op_type == OP_TREE_INC_MULTIHEAD_SELF_ATTENTION) { + if (weight_filename.find("self_attention") != std::string::npos) { + load_attention_weights_multi_query( + data, weight_filename, weights_folder, hidden_dim, num_heads); + } else if (weight_filename.find("attention") != std::string::npos && + weight_filename.rfind("attention") == + weight_filename.length() - strlen("attention")) { + if (weight_idx == 0) { + load_attention_weights_v2_b16(data, + num_heads, + num_kv_heads, + hidden_dim, + qkv_inner_dim, + weight_filename, + weights_folder, + volume, + tensor_parallelism_degree); + } else { + long long value; + l->get_int_property("final_bias", value); + bool final_bias = (bool)value; + load_attention_bias_v2_b16(data, + num_heads, + num_kv_heads, + hidden_dim, + qkv_inner_dim, + final_bias, + weight_filename, + weights_folder); + } + + } else { + assert(false); + } + } else if (l->op_type == OP_ADD_BIAS_RESIDUAL_LAYERNORM) { + assert(weight_idx >= 0 || weight_idx <= 2); + weight_filename += (weight_idx == 0) + ? "_attn_bias" + : ((weight_idx == 1) ? 
"_weight" : "_bias"); + std::cout << "Loading weight file " << weight_filename << std::endl; + std::string weight_filepath = join_path({weights_folder, weight_filename}); + load_from_file_b16(data, volume, weight_filepath); + } else { + // default op + assert(weight_idx == 0 || weight_idx == 1); + // handle exception + if (weight_filename != "embed_tokens_weight_lm_head") { + weight_filename += weight_idx == 0 ? "_weight" : "_bias"; + } + std::cout << "Loading weight file " << weight_filename << std::endl; + std::string weight_filepath = join_path({weights_folder, weight_filename}); + load_from_file_b16(data, volume, weight_filepath); + } + + // Copy the weight data from the buffer to the weight's ParallelTensor + ParallelTensor weight_pt; + ff->get_parallel_tensor_from_tensor(weight, weight_pt); + weight_pt->set_tensor<__ff_bfloat16>(ff, dims_vec, data); + + // Free buffer memory + delete data; +} + +///////////////////////////bdlot16 functions//////////// + template void load_attention_bias_v2(DT *ptr, int num_heads, @@ -356,7 +685,8 @@ void load_from_file(DT *ptr, size_t size, std::string filepath) { size_t in_get_size = in.gcount(); if (in_get_size != loaded_data_size) { std::cout << "load weight data error " << in_get_size << ", " - << loaded_data_size << ", " << sizeof(DT) << std::endl; + << loaded_data_size << ", " << sizeof(DT) + << ", filepath: " << filepath << std::endl; assert(false); } assert(size == host_array.size()); @@ -807,6 +1137,9 @@ void FileDataLoader::load_weights(FFModel *ff) { case DT_FLOAT: load_single_weight_tensor(ff, l, i); break; + case DT_BF16: + load_single_weight_tensor_b16(ff, l, i); + break; case DT_INT4: case DT_INT8: // load weights in quantization diff --git a/src/runtime/hip_helper.cpp b/src/runtime/hip_helper.cpp index fb94135c8f..ab1f9d6ece 100644 --- a/src/runtime/hip_helper.cpp +++ b/src/runtime/hip_helper.cpp @@ -299,6 +299,33 @@ __host__ void checkCUDA(hipHostFree(host_ptr)); } +template <> +__host__ void save_tensor(hip_bfloat16 const *ptr, + size_t num_elements, + char const *file_name) { + hipStream_t stream; + checkCUDA(get_legion_stream(&stream)); + hip_bfloat16 *host_ptr; + checkCUDA(hipHostMalloc(&host_ptr, + sizeof(hip_bfloat16) * num_elements, + hipHostMallocPortable | hipHostMallocMapped)); + checkCUDA(hipMemcpyAsync(host_ptr, + ptr, + sizeof(hip_bfloat16) * num_elements, + hipMemcpyDeviceToHost, + stream)); + checkCUDA(hipDeviceSynchronize()); + FILE *tensor_file; + tensor_file = fopen(file_name, "w"); + assert(tensor_file != NULL); + for (unsigned i = 0; i < num_elements; i++) { + fprintf(tensor_file, "%.9f, ", (float)host_ptr[i]); + } + + fclose(tensor_file); + checkCUDA(hipHostFree(host_ptr)); +} + template <> __host__ void save_tensor(int32_t const *ptr, size_t num_elements, @@ -489,6 +516,8 @@ miopenDataType_t ff_to_cudnn_datatype(DataType type) { switch (type) { case DT_HALF: return miopenHalf; + case DT_BF16: + return miopenBFloat16; case DT_FLOAT: return miopenFloat; case DT_DOUBLE: @@ -510,6 +539,10 @@ hipblasDatatype_t ff_to_cuda_datatype(DataType type) { return HIPBLAS_R_64F; case DT_INT32: return HIPBLAS_R_32I; + case DT_BF16: + return HIPBLAS_R_16B; + case DT_HALF: + return HIPBLAS_R_16F; default: assert(false && "Unspoorted cuda data type"); } @@ -520,6 +553,8 @@ ncclDataType_t ff_to_nccl_datatype(DataType type) { switch (type) { case DT_HALF: return ncclHalf; + case DT_BF16: + return ncclBfloat16; case DT_FLOAT: return ncclFloat; case DT_DOUBLE: @@ -540,6 +575,9 @@ void handle_unimplemented_hip_kernel(OperatorType 
op_type) { template __global__ void assign_kernel(half *ptr, coord_t size, half value); +template __global__ void assign_kernel(hip_bfloat16 *ptr, + coord_t size, + hip_bfloat16 value); template __global__ void assign_kernel(float *ptr, coord_t size, float value); template __global__ void @@ -609,7 +647,9 @@ template __host__ void save_tensor(int64_t const *ptr, char const *file_name); template __host__ void save_tensor(half const *ptr, size_t rect, char const *file_name); - +template __host__ void save_tensor(hip_bfloat16 const *ptr, + size_t rect, + char const *file_name); template __host__ float *download_tensor(float const *ptr, size_t num_elements); template __host__ half *download_tensor(half const *ptr, diff --git a/src/runtime/initializer_kernel.cpp b/src/runtime/initializer_kernel.cpp index 1005d93cec..659c8a4d2c 100644 --- a/src/runtime/initializer_kernel.cpp +++ b/src/runtime/initializer_kernel.cpp @@ -259,6 +259,17 @@ void ZeroInitializer::init_task(Task const *task, w, domain.get_volume(), 0.0f); + } else if (meta->data_types[i] == DT_BF16) { + hip_bfloat16 *w = helperGetTensorPointerWO( + regions[i], task->regions[i], FID_DATA, ctx, runtime); + hipLaunchKernelGGL(HIP_KERNEL_NAME(assign_kernel), + GET_BLOCKS(domain.get_volume()), + CUDA_NUM_THREADS, + 0, + stream, + w, + domain.get_volume(), + 0.0f); } else if (meta->data_types[i] == DT_INT32) { int32_t *w = helperGetTensorPointerWO( regions[i], task->regions[i], FID_DATA, ctx, runtime); diff --git a/src/runtime/initializer_kernel.cu b/src/runtime/initializer_kernel.cu index b6629ec90b..4c114f5de6 100644 --- a/src/runtime/initializer_kernel.cu +++ b/src/runtime/initializer_kernel.cu @@ -235,6 +235,12 @@ void ZeroInitializer::init_task(Task const *task, assign_kernel <<>>( w, domain.get_volume(), 0.0f); + } else if (meta->data_types[i] == DT_BF16) { + __nv_bfloat16 *w = helperGetTensorPointerWO<__nv_bfloat16>( + regions[i], task->regions[i], FID_DATA, ctx, runtime); + assign_kernel<__nv_bfloat16> + <<>>( + w, domain.get_volume(), 0.0f); } else if (meta->data_types[i] == DT_FLOAT) { float *w = helperGetTensorPointerWO( regions[i], task->regions[i], FID_DATA, ctx, runtime); diff --git a/src/runtime/model.cc b/src/runtime/model.cc index c07c33efca..a826bb5c34 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -1947,6 +1947,9 @@ void FFModel::map_tensor_with_dim2(ParallelTensor tensor, case DT_HALF: allocator.allocate_field(sizeof(half), FID_DATA); break; + case DT_BF16: + allocator.allocate_field(sizeof(__ff_bfloat16), FID_DATA); + break; case DT_FLOAT: allocator.allocate_field(sizeof(float), FID_DATA); break; diff --git a/src/runtime/operator.cc b/src/runtime/operator.cc index 0b3813f41c..3cb71f910e 100644 --- a/src/runtime/operator.cc +++ b/src/runtime/operator.cc @@ -62,6 +62,10 @@ void Op::save_inference_tensors_to_file( save_tensor(input_tensors[i].get_half_ptr(), input_tensors[i].domain.get_volume(), filename.c_str()); + } else if (input_tensors[i].data_type == DT_BF16) { + save_tensor(input_tensors[i].get_bfloat16_ptr(), + input_tensors[i].domain.get_volume(), + filename.c_str()); } else if (input_tensors[i].data_type == DT_INT32) { save_tensor(input_tensors[i].get_int32_ptr(), input_tensors[i].domain.get_volume(), @@ -86,6 +90,10 @@ void Op::save_inference_tensors_to_file( save_tensor(weight_tensors[i].get_half_ptr(), weight_tensors[i].domain.get_volume(), filename.c_str()); + } else if (weight_tensors[i].data_type == DT_BF16) { + save_tensor(weight_tensors[i].get_bfloat16_ptr(), + 
weight_tensors[i].domain.get_volume(), + filename.c_str()); } else if (weight_tensors[i].data_type == DT_INT32) { save_tensor(weight_tensors[i].get_int32_ptr(), weight_tensors[i].domain.get_volume(), @@ -110,6 +118,10 @@ void Op::save_inference_tensors_to_file( save_tensor(output_tensors[i].get_half_ptr(), output_tensors[i].domain.get_volume(), filename.c_str()); + } else if (output_tensors[i].data_type == DT_BF16) { + save_tensor(output_tensors[i].get_bfloat16_ptr(), + output_tensors[i].domain.get_volume(), + filename.c_str()); } else if (output_tensors[i].data_type == DT_INT32) { save_tensor(output_tensors[i].get_int32_ptr(), output_tensors[i].domain.get_volume(), diff --git a/src/runtime/parallel_tensor.cc b/src/runtime/parallel_tensor.cc index 8f1be15fd1..d3cc8a0cb1 100644 --- a/src/runtime/parallel_tensor.cc +++ b/src/runtime/parallel_tensor.cc @@ -847,6 +847,12 @@ template bool ParallelTensorBase::get_tensor(FFModel const *ff, half *data, bool get_gradients); +template bool ParallelTensorBase::set_tensor<__ff_bfloat16>( + FFModel const *ff, std::vector const &dims, __ff_bfloat16 const *data); +template bool ParallelTensorBase::get_tensor<__ff_bfloat16>(FFModel const *ff, + __ff_bfloat16 *data, + bool get_gradients); + template bool ParallelTensorBase::set_tensor(FFModel const *ff, std::vector const &dims, char const *data); diff --git a/tests/inference/huggingface_inference.py b/tests/inference/huggingface_inference.py index 5b533bf3c0..4a8b5920ea 100644 --- a/tests/inference/huggingface_inference.py +++ b/tests/inference/huggingface_inference.py @@ -26,6 +26,9 @@ def main(): parser.add_argument( "--use-full-precision", action="store_true", help="Use full precision" ) + parser.add_argument( + "--use-bfloat16-precision", action="store_true", help="Use bf16 precision" + ) parser.add_argument("--do-sample", action="store_true", help="Use sampling") parser.add_argument("--gpu", action="store_true", help="Run on GPU") args = parser.parse_args() @@ -47,7 +50,11 @@ def main(): return # Set default tensor type depending on argument indicating the float type to use - if not args.use_full_precision: + if args.use_full_precision: + torch.set_default_tensor_type(torch.FloatTensor) + elif args.use_bfloat16_precision: + torch.set_default_tensor_type(torch.BFloat16Tensor) + else: torch.set_default_tensor_type(torch.HalfTensor) # Run huggingface model
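A note on the alpha/beta rework in linear_kernels.cu and spec_inc_multihead_self_attention.cu above: cublasGemmEx takes its alpha/beta scaling factors in the type implied by the compute type (half scalars for CUBLAS_COMPUTE_16F, float scalars for the 32-bit compute types), not in the matrix element type. Because the DT_BF16 path selects CUBLAS_COMPUTE_32F, its scalars must stay float even though A/B/C hold __nv_bfloat16. The sketch below re-derives that pattern; GemmScalar and gemm_example are illustrative names only, not part of this patch or of FlexFlow.

```cpp
// Minimal sketch, assuming CUDA 11+ with cuBLAS. Illustration only, not FlexFlow code.
#include <cublas_v2.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>

template <typename DT>
struct GemmScalar {
  using type = float; // float and __nv_bfloat16 GEMMs accumulate in fp32
};
template <>
struct GemmScalar<half> {
  using type = half; // fp16 fast path (CUBLAS_COMPUTE_16F) wants half scalars
};

template <typename DT>
cublasStatus_t gemm_example(cublasHandle_t blas, int m, int n, int k,
                            DT const *A, DT const *B, DT *C,
                            cudaDataType_t ab_type,
                            cublasComputeType_t compute_type) {
  // The scalar type follows the compute type, not DT: the bf16 instantiation
  // passes float pointers for alpha/beta while A/B/C stay __nv_bfloat16.
  typename GemmScalar<DT>::type alpha = 1.0, beta = 0.0;
  return cublasGemmEx(blas, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                      &alpha, A, ab_type, m,
                      B, ab_type, k,
                      &beta, C, ab_type, m,
                      compute_type, CUBLAS_GEMM_DEFAULT);
}
```

A bf16 call would look like gemm_example<__nv_bfloat16>(handle, m, n, k, A, B, C, CUDA_R_16BF, CUBLAS_COMPUTE_32F), mirroring how the patch maps DT_BF16 to CUDA_R_16BF in ff_to_cuda_datatype and to CUBLAS_COMPUTE_32F in the attention and linear kernels.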
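The *_b16 loaders added to file_loader.cc keep the checkpoint weights in fp32 on disk and narrow each value with __float2bfloat16 while filling the destination buffer, which is why load_from_file_b16 reads sizeof(float) * size bytes even though the tensor itself is 16-bit. As a host-only illustration of what that narrowing does to the bit pattern (round-to-nearest-even on the 16 dropped bits), here is a standalone sketch; float_to_bf16_bits and narrow_weights are hypothetical helpers, not code from this patch.

```cpp
// Standalone sketch of the fp32 -> bf16 narrowing done at weight-load time.
// NaN special-casing is omitted for brevity.
#include <cstdint>
#include <cstring>
#include <vector>

static uint16_t float_to_bf16_bits(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));               // portable type-pun
  uint32_t rounding = 0x7FFFu + ((bits >> 16) & 1u);  // round half to even
  return static_cast<uint16_t>((bits + rounding) >> 16);
}

// Shaped like load_from_file_b16: take the fp32 host_array, emit the bf16 payload.
std::vector<uint16_t> narrow_weights(std::vector<float> const &host_array) {
  std::vector<uint16_t> out;
  out.reserve(host_array.size());
  for (float v : host_array) {
    out.push_back(float_to_bf16_bits(v));
  }
  return out;
}
```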