New Datatype: support __nv_bfloat16 #1264

Draft · wants to merge 12 commits into base: inference
11 changes: 11 additions & 0 deletions include/flexflow/accessor.h
@@ -5,16 +5,25 @@

#if defined(FF_USE_CUDA)
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#elif defined(FF_USE_HIP_CUDA)
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#elif defined(FF_USE_HIP_ROCM)
#include <hip/hip_fp16.h>
#include <hip_bfloat16.h>
#endif

// using namespace Legion;

namespace FlexFlow {

#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
typedef __nv_bfloat16 __ff_bfloat16;
#elif defined(FF_USE_HIP_ROCM)
typedef hip_bfloat16 __ff_bfloat16;
#endif

template <typename FT, int N, typename T = Legion::coord_t>
using AccessorRO =
Legion::FieldAccessor<READ_ONLY, FT, N, T, Realm::AffineAccessor<FT, N, T>>;
@@ -61,6 +70,7 @@ class GenericTensorAccessorW {
float *get_float_ptr() const;
double *get_double_ptr() const;
half *get_half_ptr() const;
__ff_bfloat16 *get_bfloat16_ptr() const;
char *get_byte_ptr() const;
DataType data_type;
Legion::Domain domain;
@@ -80,6 +90,7 @@ class GenericTensorAccessorR {
float const *get_float_ptr() const;
double const *get_double_ptr() const;
half const *get_half_ptr() const;
__ff_bfloat16 const *get_bfloat16_ptr() const;
char const *get_byte_ptr() const;
DataType data_type;
Legion::Domain domain;
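
To illustrate what the new alias buys, here is a minimal standalone sketch (not taken from the PR): call sites only ever name __ff_bfloat16, and the preprocessor binds it to the CUDA or ROCm storage type. ToyAccessorW and fill_ones are hypothetical stand-ins for GenericTensorAccessorW and its callers; the sketch assumes an FF_USE_CUDA build.

#include <cstddef>
#include <cuda_bf16.h> // a ROCm build would include hip_bfloat16 instead

typedef __nv_bfloat16 __ff_bfloat16; // same alias the diff introduces

struct ToyAccessorW {          // stand-in, not the FlexFlow accessor class
  void *ptr;
  __ff_bfloat16 *get_bfloat16_ptr() const {
    return static_cast<__ff_bfloat16 *>(ptr);
  }
};

// Fill a bf16 buffer with ones through the typed accessor (host-side).
void fill_ones(ToyAccessorW const &acc, size_t n) {
  __ff_bfloat16 *out = acc.get_bfloat16_ptr();
  for (size_t i = 0; i < n; i++) {
    out[i] = __float2bfloat16(1.0f); // fp32 -> bf16 conversion intrinsic
  }
}
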
7 changes: 7 additions & 0 deletions include/flexflow/ffconst.h
@@ -35,6 +35,7 @@ enum DataType {
DT_DOUBLE = 45,
DT_INT4 = 46,
DT_INT8 = 47,
DT_BF16 = 48,
DT_NONE = 49,
};

@@ -269,4 +270,10 @@ enum {
PARALLEL_TENSOR_GUID_FIRST_VALID = 4000000,
NODE_GUID_FIRST_VALID = 5000000,
};

enum InferencePrecision {
INFERENCE_FLOAT = 800,
INFERENCE_HALF = 801,
INFERENCE_BFLOAT16 = 802,
};
#endif // _FLEXFLOW_CONST_H_
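
As a hedged illustration of how the new enum values fit together, the hypothetical helper below (not part of the PR) maps InferencePrecision onto the matching DataType, the way a front end might translate a user-facing precision setting. The local enums are trimmed stand-ins for the ones in ffconst.h.

#include <cassert>

enum DataType { DT_HALF, DT_FLOAT, DT_BF16, DT_NONE }; // stand-in subset
enum InferencePrecision {
  INFERENCE_FLOAT = 800,
  INFERENCE_HALF = 801,
  INFERENCE_BFLOAT16 = 802,
};

DataType to_data_type(InferencePrecision p) {
  switch (p) {
    case INFERENCE_FLOAT:
      return DT_FLOAT;
    case INFERENCE_HALF:
      return DT_HALF;
    case INFERENCE_BFLOAT16:
      return DT_BF16;
  }
  assert(false && "unhandled InferencePrecision");
  return DT_NONE;
}
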
114 changes: 114 additions & 0 deletions include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh
@@ -23,6 +23,25 @@ struct half8 {
half c;
half d;
};

struct __nv_bfloat164 {
__nv_bfloat16 x;
__nv_bfloat16 y;
__nv_bfloat16 z;
__nv_bfloat16 w;
};

struct __nv_bfloat168 {
__nv_bfloat16 x;
__nv_bfloat16 y;
__nv_bfloat16 z;
__nv_bfloat16 w;
__nv_bfloat16 a;
__nv_bfloat16 b;
__nv_bfloat16 c;
__nv_bfloat16 d;
};

struct float8 {
float x;
float y;
@@ -61,6 +80,18 @@ template <>
struct VEC_K<half, 4> {
using Type = half4;
};
template <>
struct VEC_K<__nv_bfloat16, 1> {
using Type = __nv_bfloat16;
};
template <>
struct VEC_K<__nv_bfloat16, 2> {
using Type = __nv_bfloat162;
};
template <>
struct VEC_K<__nv_bfloat16, 4> {
using Type = __nv_bfloat164;
};

// data type for QK production
template <typename T>
@@ -95,6 +126,23 @@ struct Vec_fp32_<half8> {
using Type = float8;
};

template <>
struct Vec_fp32_<__nv_bfloat16> {
using Type = float;
};
template <>
struct Vec_fp32_<__nv_bfloat162> {
using Type = float2;
};
template <>
struct Vec_fp32_<__nv_bfloat164> {
using Type = float4;
};
template <>
struct Vec_fp32_<__nv_bfloat168> {
using Type = float8;
};

template <typename DT>
struct VEC_V {};
template <>
@@ -105,6 +153,10 @@ template <>
struct VEC_V<half> {
using Type = half8;
};
template <>
struct VEC_V<__nv_bfloat16> {
using Type = __nv_bfloat168;
};

////////////////data structures half///////////////

@@ -331,6 +383,42 @@ inline __device__ float8 cast_to_float(half8 u) {
return tmp;
}

inline __device__ float cast_to_float(__nv_bfloat16 u) {
return __bfloat162float(u);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float2 cast_to_float(__nv_bfloat162 u) {
float2 tmp;
tmp.x = __bfloat162float(u.x);
tmp.y = __bfloat162float(u.y);
return tmp;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float4 cast_to_float(__nv_bfloat164 u) {
float4 tmp;
tmp.x = __bfloat162float(u.x);
tmp.y = __bfloat162float(u.y);
tmp.z = __bfloat162float(u.z);
tmp.w = __bfloat162float(u.w);
return tmp;
}
inline __device__ float8 cast_to_float(__nv_bfloat168 u) {
float8 tmp;
tmp.x = __bfloat162float(u.x);
tmp.y = __bfloat162float(u.y);
tmp.z = __bfloat162float(u.z);
tmp.w = __bfloat162float(u.w);
tmp.a = __bfloat162float(u.a);
tmp.b = __bfloat162float(u.b);
tmp.c = __bfloat162float(u.c);
tmp.d = __bfloat162float(u.d);
return tmp;
}

inline __device__ void convert_from_float(float4 &dst, float4 src) {
dst = src;
}
@@ -369,6 +457,32 @@ inline __device__ void convert_from_float(half &dst, float src) {
dst = __float2half(src);
}

//
inline __device__ void convert_from_float(__nv_bfloat164 &dst, float4 src) {
dst.x = __float2bfloat16(src.x);
dst.y = __float2bfloat16(src.y);
dst.z = __float2bfloat16(src.z);
dst.w = __float2bfloat16(src.w);
}

inline __device__ void convert_from_float(__nv_bfloat168 &dst, float8 src) {
dst.x = __float2bfloat16(src.x);
dst.y = __float2bfloat16(src.y);
dst.z = __float2bfloat16(src.z);
dst.w = __float2bfloat16(src.w);
dst.a = __float2bfloat16(src.a);
dst.b = __float2bfloat16(src.b);
dst.c = __float2bfloat16(src.c);
dst.d = __float2bfloat16(src.d);
}
inline __device__ void convert_from_float(__nv_bfloat162 &dst, float2 src) {
dst.x = __float2bfloat16(src.x);
dst.y = __float2bfloat16(src.y);
}
inline __device__ void convert_from_float(__nv_bfloat16 &dst, float src) {
dst = __float2bfloat16(src);
}

//////////////////////////////////////utils///////////////////////////////////////////////

template <typename T>
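
A minimal CUDA sketch of the round trip these helpers enable: read a 2-wide bf16 vector, promote it to fp32 with cast_to_float, scale it, and write it back with convert_from_float. The two helpers are copied from the diff above so the snippet stands alone; scale_bf16x2 is a hypothetical kernel, not part of the PR.

#include <cuda_bf16.h>

inline __device__ float2 cast_to_float(__nv_bfloat162 u) {
  float2 tmp;
  tmp.x = __bfloat162float(u.x);
  tmp.y = __bfloat162float(u.y);
  return tmp;
}
inline __device__ void convert_from_float(__nv_bfloat162 &dst, float2 src) {
  dst.x = __float2bfloat16(src.x);
  dst.y = __float2bfloat16(src.y);
}

// Scale a bf16 tensor in fp32 precision, two elements per thread; this is
// the same load-as-bf16 / compute-in-fp32 pattern the attention kernels get
// from the VEC_K and Vec_fp32_ traits.
__global__ void scale_bf16x2(__nv_bfloat162 *data, float alpha, int n_vec2) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n_vec2) {
    float2 v = cast_to_float(data[i]); // bf16x2 -> fp32x2
    v.x *= alpha;
    v.y *= alpha;
    convert_from_float(data[i], v);    // fp32x2 -> bf16x2
  }
}
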
14 changes: 14 additions & 0 deletions include/flexflow/utils/cuda_helper.h
@@ -161,6 +161,20 @@ T *download_tensor(T const *ptr, size_t num_elements);
template <typename T>
bool download_tensor(T const *ptr, T *dst, size_t num_elements);

// data type for cublasgemm
template <typename T>
struct cublasAlphaBetaType {
using type = float; // default
};
template <>
struct cublasAlphaBetaType<half> {
using type = half;
};
template <>
struct cublasAlphaBetaType<__nv_bfloat16> {
using type = float;
};

cudnnStatus_t cudnnSetTensorDescriptorFromDomain(cudnnTensorDescriptor_t tensor,
Legion::Domain domain,
DataType data_type = DT_FLOAT);
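
The bf16 specialization maps to float because cublasGemmEx with CUDA_R_16BF matrices runs with an fp32 compute type, so alpha and beta have to be passed as floats, whereas pure fp16 GEMMs may pass half scalars. The wrapper below is a hedged, hypothetical call site (gemm_sketch is not a FlexFlow function); the trait is copied locally so the snippet is self-contained.

#include <type_traits>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

// Local copy of the trait added in the diff above.
template <typename T> struct cublasAlphaBetaType { using type = float; };
template <> struct cublasAlphaBetaType<half> { using type = half; };
template <> struct cublasAlphaBetaType<__nv_bfloat16> { using type = float; };

// Hypothetical wrapper: C = A * B for column-major m x k and k x n inputs.
template <typename DT>
cublasStatus_t gemm_sketch(cublasHandle_t handle, int m, int n, int k,
                           DT const *A, DT const *B, DT *C) {
  using AB = typename cublasAlphaBetaType<DT>::type; // float when DT is bf16
  AB alpha = static_cast<AB>(1.0f);
  AB beta = static_cast<AB>(0.0f);
  cudaDataType io_type = std::is_same<DT, __nv_bfloat16>::value ? CUDA_R_16BF
                         : std::is_same<DT, half>::value        ? CUDA_R_16F
                                                                : CUDA_R_32F;
  // bf16 matrices use an fp32 compute type, hence the float alpha/beta.
  cublasComputeType_t compute_type = std::is_same<DT, half>::value
                                         ? CUBLAS_COMPUTE_16F
                                         : CUBLAS_COMPUTE_32F;
  return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                      &alpha, A, io_type, /*lda=*/m, B, io_type, /*ldb=*/k,
                      &beta, C, io_type, /*ldc=*/m,
                      compute_type, CUBLAS_GEMM_DEFAULT);
}
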
4 changes: 4 additions & 0 deletions include/flexflow/utils/file_loader.h
@@ -38,6 +38,8 @@ class FileDataLoader {
template <typename DT>
void load_single_weight_tensor(FFModel *ff, Layer *l, int weight_idx);

void load_single_weight_tensor_b16(FFModel *ff, Layer *l, int weight_idx);

void load_quantization_weight(FFModel *ff, Layer *l, int weight_idx);
void load_weights(FFModel *ff);

@@ -46,6 +48,8 @@
ParallelTensor position_pt,
int max_seq_length,
int offset);
// template <typename DT>
// void load_from_file(DT *ptr, size_t size, std::string filepath);

private:
int num_heads, num_kv_heads, tensor_parallelism_degree;
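
The PR does not show the body of load_single_weight_tensor_b16, but the "bfloat16 shares same weight file with float32" comment in incr_decoding.cc below suggests the idea: read the full-precision dump and narrow each value to bf16. A hedged, hypothetical sketch of that conversion step (load_fp32_as_bf16 is not a FlexFlow function):

#include <cstddef>
#include <cstdio>
#include <vector>
#include <cuda_bf16.h>

// Read num_elements fp32 values from a full-precision weight file and
// convert them to bf16 on the host.
static bool load_fp32_as_bf16(char const *filepath,
                              std::vector<__nv_bfloat16> &out,
                              size_t num_elements) {
  std::vector<float> host_fp32(num_elements);
  FILE *f = std::fopen(filepath, "rb");
  if (f == nullptr) {
    return false;
  }
  size_t read = std::fread(host_fp32.data(), sizeof(float), num_elements, f);
  std::fclose(f);
  if (read != num_elements) {
    return false;
  }
  out.resize(num_elements);
  for (size_t i = 0; i < num_elements; i++) {
    out[i] = __float2bfloat16(host_fp32[i]); // narrow fp32 -> bf16
  }
  return true;
}
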
46 changes: 28 additions & 18 deletions inference/incr_decoding/incr_decoding.cc
@@ -40,7 +40,9 @@ void parse_input_args(char **argv,
int argc,
FilePaths &paths,
std::string &llm_model_name,
DataType &data_type,
bool &use_full_precision,
bool &use_bfloat16_precision,
bool &verbose,
bool &do_sample,
float &temperature,
@@ -74,6 +76,12 @@
}
if (!strcmp(argv[i], "--use-full-precision")) {
use_full_precision = true;
data_type = DT_FLOAT;
continue;
}
if (!strcmp(argv[i], "--use-bfloat16-precision")) {
use_bfloat16_precision = true;
data_type = DT_BF16;
continue;
}
// verbose logging to stdout
@@ -126,7 +134,9 @@ void FlexFlow::top_level_task(Task const *task,
}
FilePaths file_paths;
std::string llm_model_name;
DataType data_type = DT_HALF;
bool use_full_precision = false;
bool use_bfloat16_precision = false;
bool verbose = false;
bool do_sample = false;
float temperature = 0.0f;
@@ -142,7 +152,9 @@ void FlexFlow::top_level_task(Task const *task,
argc,
file_paths,
llm_model_name,
data_type,
use_full_precision,
use_bfloat16_precision,
verbose,
do_sample,
temperature,
@@ -159,11 +171,15 @@
{file_paths.cache_folder_path, "configs", llm_model_name, "config.json"});
std::string tokenizer_filepath =
join_path({file_paths.cache_folder_path, "tokenizers", llm_model_name});
std::string weights_filepath =
join_path({file_paths.cache_folder_path,
"weights",
llm_model_name,
use_full_precision ? "full-precision" : "half-precision"});

// bfloat16 shares same weight file with float32
std::string weights_filepath = join_path(
{file_paths.cache_folder_path,
"weights",
llm_model_name,
use_full_precision
? "full-precision"
: (use_bfloat16_precision ? "full-precision" : "half-precision")});
std::ifstream config_file_handle(config_filepath);
if (!config_file_handle.good()) {
std::cout << "Model config file " << config_filepath << " not found."
@@ -220,33 +236,27 @@ void FlexFlow::top_level_task(Task const *task,
weights_filepath,
INC_DECODING_MODE,
generationConfig,
use_full_precision);
data_type);
} else if (model_type == ModelType::OPT) {
OPT::create_opt_model(model,
config_filepath,
weights_filepath,
INC_DECODING_MODE,
use_full_precision);
OPT::create_opt_model(
model, config_filepath, weights_filepath, INC_DECODING_MODE, data_type);
} else if (model_type == ModelType::FALCON) {
FALCON::create_falcon_model(model,
config_filepath,
weights_filepath,
INC_DECODING_MODE,
use_full_precision);
FALCON::create_falcon_model(
model, config_filepath, weights_filepath, INC_DECODING_MODE, data_type);
} else if (model_type == ModelType::STARCODER) {
STARCODER::create_starcoder_model(model,
config_filepath,
weights_filepath,
INC_DECODING_MODE,
generationConfig,
use_full_precision);
data_type);
} else if (model_type == ModelType::MPT) {
MPT::create_mpt_model(model,
config_filepath,
weights_filepath,
INC_DECODING_MODE,
generationConfig,
use_full_precision);
data_type);
} else {
assert(false && "unknow model type");
}
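
To condense the plumbing this file adds: each precision flag sets both a bool and the DataType handed to the model builders, and the weight folder falls back to the full-precision files when bf16 is requested because no separate bf16 dump exists. A hedged restatement as a small hypothetical helper (PrecisionChoice is not part of the PR):

#include <string>
#include "flexflow/ffconst.h" // DataType, DT_HALF, DT_BF16 (extended by this PR)

struct PrecisionChoice {
  DataType data_type = DT_HALF;        // default, as in top_level_task
  bool use_full_precision = false;     // set by --use-full-precision
  bool use_bfloat16_precision = false; // set by --use-bfloat16-precision

  std::string weight_folder() const {
    if (use_full_precision) {
      return "full-precision";
    }
    // bf16 reuses the fp32 weight files; only fp16 has its own dump.
    return use_bfloat16_precision ? "full-precision" : "half-precision";
  }
};
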
8 changes: 4 additions & 4 deletions inference/models/falcon.cc
@@ -24,7 +24,7 @@ void FALCON::create_falcon_model(FFModel &ff,
std::string const &model_config_file_path,
std::string const &weight_file_path,
InferenceMode mode,
bool use_full_precision) {
DataType data_type) {
FalconConfig falcon_config(model_config_file_path);
falcon_config.print();

@@ -54,7 +54,7 @@ void FALCON::create_falcon_model(FFModel &ff,
falcon_config.vocab_size,
falcon_config.hidden_size,
AGGR_MODE_NONE,
use_full_precision ? DT_FLOAT : DT_HALF,
data_type,
NULL,
embed_init,
"word_embeddings");
@@ -248,7 +248,7 @@ void FALCON::create_falcon_model(FFModel &ff,
falcon_config.hidden_size,
falcon_config.hidden_size / falcon_config.n_head,
ff.config.tensor_parallelism_degree,
use_full_precision);
true);

InferenceManager *im = InferenceManager::get_inference_manager();
im->register_model_weights_loader(&ff, fileloader);
@@ -266,7 +266,7 @@
falcon_config.hidden_size / falcon_config.n_head,
ff.config.tensor_parallelism_degree);
std::cout << "------load weights ----------" << std::endl;
fileloader.load_weights(&ff, use_full_precision);
fileloader.load_weights(&ff, false);
std::cout << "------load weight finished----------" << std::endl;

// init operators