From 50ff264ee02956bd464e575c70011ed9550bef56 Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Sun, 8 Oct 2023 19:39:01 -0400 Subject: [PATCH] [Tool] - Add mechanism to save operators' tensors to file (#1174) * add model id, layer_id and op_name to opmeta * pass model id to opmeta * . * implement inference tensor save function * add calls to save tensors function in ops * more ops * done * fix bugs, implement batchconfig << operator, add function to save bc to file * fixes * hip_rocm fixes * fix * fix bug * fix ci * removed out of date incmha inference test * add save tensors function to fused.cu --- .gitignore | 2 + conda/pytorch-gpu.yml | 2 +- include/flexflow/batch_config.h | 13 +- include/flexflow/config.h | 1 + include/flexflow/fftype.h | 4 +- include/flexflow/layer.h | 1 + include/flexflow/model.h | 5 + include/flexflow/op_meta.h | 4 + include/flexflow/operator.h | 10 + .../ops/add_bias_residual_layer_norm.h | 1 - include/flexflow/ops/element_unary.h | 1 - .../ops/inc_multihead_self_attention.h | 3 - include/flexflow/ops/kernels/concat_kernels.h | 1 - .../flexflow/ops/kernels/conv_2d_kernels.h | 1 - .../ops/kernels/element_binary_kernels.h | 1 - include/flexflow/ops/kernels/linear_kernels.h | 1 - .../flexflow/ops/kernels/pool_2d_kernels.h | 1 - .../ops/kernels/residual_rms_norm_kernels.h | 1 - .../flexflow/ops/kernels/rms_norm_kernels.h | 1 - .../flexflow/ops/kernels/softmax_kernels.h | 2 +- include/flexflow/ops/layer_norm.h | 1 - include/flexflow/ops/linear.h | 1 + include/flexflow/ops/residual_layer_norm.h | 1 - include/flexflow/ops/residual_rms_norm.h | 1 - include/flexflow/ops/rms_norm.h | 1 - include/flexflow/ops/sigmoid_silu_multi.h | 1 - include/flexflow/utils/hip_helper.h | 3 + inference/python/incr_decoding.py | 1 + inference/python/spec_infer.py | 1 + python/flexflow/core/__init__.py | 1 + python/flexflow/serve/__init__.py | 8 + src/c/flexflow_c.cc | 3 +- src/ops/add_bias_residual_layer_norm.cc | 65 +- src/ops/add_bias_residual_layer_norm.cpp | 1 + src/ops/add_bias_residual_layer_norm.cu | 1 + src/ops/aggregate.cc | 3 + src/ops/aggregate_spec.cc | 3 + src/ops/arg_topk.cc | 18 +- src/ops/argmax.cc | 21 +- src/ops/attention.cc | 3 + src/ops/batch_matmul.cc | 3 + src/ops/batch_norm.cpp | 1 + src/ops/batch_norm.cu | 1 + src/ops/beam_topk.cc | 78 +- src/ops/cache.cc | 3 + src/ops/cast.cc | 2 + src/ops/concat.cc | 2 + src/ops/conv_2d.cc | 8 +- src/ops/dropout.cc | 2 + src/ops/element_binary.cc | 19 +- src/ops/element_unary.cc | 31 +- src/ops/embedding.cc | 77 +- src/ops/experts.cc | 39 +- src/ops/fused.cu | 25 + src/ops/gather.cc | 8 +- src/ops/group_by.cc | 41 +- src/ops/inc_multihead_self_attention.cc | 797 +----------------- src/ops/inc_multihead_self_attention.cu | 15 - src/ops/kernels/dropout_kernels.cpp | 1 + src/ops/kernels/dropout_kernels.cu | 1 + src/ops/kernels/element_binary_kernels.cu | 1 + src/ops/kernels/softmax.cpp | 1 + src/ops/kernels/softmax.cu | 1 + src/ops/layer_norm.cc | 25 +- src/ops/layer_norm.cu | 1 + src/ops/linear.cc | 22 +- src/ops/pool_2d.cc | 2 + src/ops/reduce.cc | 8 +- src/ops/reshape.cc | 8 +- src/ops/residual_layer_norm.cc | 36 +- src/ops/residual_layer_norm.cpp | 1 + src/ops/residual_layer_norm.cu | 1 + src/ops/residual_rms_norm.cc | 17 +- src/ops/rms_norm.cc | 16 +- src/ops/sampling.cc | 12 +- src/ops/sigmoid_silu_multi.cc | 17 +- src/ops/sigmoid_silu_multi.cpp | 1 + src/ops/sigmoid_silu_multi.cu | 1 + src/ops/softmax.cc | 89 +- src/ops/spec_inc_multihead_self_attention.cc | 25 +- src/ops/topk.cc | 3 + src/ops/transpose.cc | 3 + 
src/ops/tree_inc_multihead_self_attention.cc | 16 +- src/runtime/batch_config.cc | 84 +- src/runtime/beam_search_batch_config.cc | 140 +-- src/runtime/cuda_helper.cu | 104 ++- src/runtime/fftype.cc | 13 +- src/runtime/graph.cc | 30 +- src/runtime/hip_helper.cpp | 117 +++ src/runtime/layer.cc | 6 +- src/runtime/model.cc | 20 +- src/runtime/operator.cc | 109 +++ src/runtime/tree_verify_batch_config.cc | 100 ++- .../python_test_configs/generate_configs.py | 1 + 94 files changed, 1109 insertions(+), 1270 deletions(-) diff --git a/.gitignore b/.gitignore index be0266c9b5..8fcc105f01 100644 --- a/.gitignore +++ b/.gitignore @@ -185,3 +185,5 @@ gpt_tokenizer # pip version python/flexflow/version.txt + +inference_tensors \ No newline at end of file diff --git a/conda/pytorch-gpu.yml b/conda/pytorch-gpu.yml index 677e71d73f..85d24ced17 100644 --- a/conda/pytorch-gpu.yml +++ b/conda/pytorch-gpu.yml @@ -3,7 +3,7 @@ channels: - defaults - conda-forge dependencies: - - python>=3.6 + - python>=3.6,<3.12 - pip - pip: - numpy>=1.16.0 diff --git a/include/flexflow/batch_config.h b/include/flexflow/batch_config.h index 6dabc70f4b..108bc8d172 100644 --- a/include/flexflow/batch_config.h +++ b/include/flexflow/batch_config.h @@ -46,7 +46,9 @@ class BatchConfig { static int max_requests_per_batch(); static int max_tokens_per_batch(); static int max_sequence_length(); + friend std::ostream &operator<<(std::ostream &os, BatchConfig const &bc); void print() const; + void save_to_file(std::string const &filename) const; virtual InferenceMode get_mode() const; static BatchConfig const *from_future(BatchConfigFuture const &future); // Maximum possible values for different parameters @@ -55,9 +57,8 @@ class BatchConfig { static int const MAX_NUM_REQUESTS = 64; static int const MAX_NUM_TOKENS = 1024; - // These are set by update + // Set by update int num_tokens; - bool loading_prompt = false; struct PerRequestInfo { int token_start_offset; @@ -74,7 +75,7 @@ class BatchConfig { PerTokenInfo tokensInfo[MAX_NUM_TOKENS]; bool request_completed[MAX_NUM_REQUESTS]; - bool request_running[MAX_NUM_TOKENS]; + bool request_running[MAX_NUM_REQUESTS]; }; class TreeVerifyBatchConfig : public BatchConfig { @@ -82,7 +83,10 @@ class TreeVerifyBatchConfig : public BatchConfig { TreeVerifyBatchConfig(); ~TreeVerifyBatchConfig(); InferenceMode get_mode() const; + friend std::ostream &operator<<(std::ostream &os, + TreeVerifyBatchConfig const &bc); void print() const; + void save_to_file(std::string const &filename) const; struct CommittedTokensInfo { int token_index; // the index of the token in the previous batch int request_index; // request index in the batch @@ -108,7 +112,10 @@ class BeamSearchBatchConfig : public BatchConfig { ~BeamSearchBatchConfig(); + friend std::ostream &operator<<(std::ostream &os, + BeamSearchBatchConfig const &bc); void print() const; + void save_to_file(std::string const &filename) const; bool done() const; int max_beam_depth_all_requests() const; int current_depth_all_requests() const; diff --git a/include/flexflow/config.h b/include/flexflow/config.h index 9716060173..c2af6d707c 100644 --- a/include/flexflow/config.h +++ b/include/flexflow/config.h @@ -134,6 +134,7 @@ class FFConfig { Legion::Runtime *lg_hlr; // Legion::FieldSpace field_space; bool syntheticInput, profiling, perform_fusion; + bool inference_debugging; size_t simulator_work_space_size; size_t search_budget; float search_alpha; diff --git a/include/flexflow/fftype.h b/include/flexflow/fftype.h index 18ed6b8100..1cd90fda26 100644 --- 
a/include/flexflow/fftype.h +++ b/include/flexflow/fftype.h @@ -10,12 +10,12 @@ class LayerID { public: static const LayerID NO_ID; LayerID(); - LayerID(size_t id, size_t transformer_layer_id); + LayerID(size_t id, size_t transformer_layer_id, size_t model_id); bool is_valid_id() const; friend bool operator==(LayerID const &lhs, LayerID const &rhs); public: - size_t id, transformer_layer_id; + size_t id, transformer_layer_id, model_id; }; }; // namespace FlexFlow diff --git a/include/flexflow/layer.h b/include/flexflow/layer.h index 0c1d7a6092..69a57e4e1c 100644 --- a/include/flexflow/layer.h +++ b/include/flexflow/layer.h @@ -52,6 +52,7 @@ class Layer { bool trainableInputs[MAX_NUM_INPUTS]; int numInputs, numWeights, numOutputs; bool profiling; + bool inference_debugging; private: std::unordered_map int_properties; diff --git a/include/flexflow/model.h b/include/flexflow/model.h index 97ee553fb3..d8402ba622 100644 --- a/include/flexflow/model.h +++ b/include/flexflow/model.h @@ -1234,6 +1234,8 @@ class FFModel { std::unordered_map cached_noop_ops; std::unordered_map cached_input_ops; std::vector all_valid_views; + int model_id; // unique incremental id assigned to each model. Used in the + // inference_debugging mode. #ifdef FF_USE_NCCL std::unordered_map view_hash_to_nccl_comms; #endif @@ -1262,6 +1264,9 @@ class FFModel { ElementUnary * unary(OperatorType op, char const *name = NULL, float scalar = 0.0); PCG::Node new_node(Op *); + static int model_counter; // number of instantiated FFModel objects. Used to + // assign a unique incremental id to each model. + // Used in the inference_debugging mode. }; class UtilityTasks { diff --git a/include/flexflow/op_meta.h b/include/flexflow/op_meta.h index 512844db92..60785a1e29 100644 --- a/include/flexflow/op_meta.h +++ b/include/flexflow/op_meta.h @@ -15,6 +15,10 @@ class OpMeta { public: FFHandler handle; bool profiling; // Measure the run time of the task + bool inference_debugging; + int decoding_step; + char op_name[MAX_OPNAME]; + LayerID layer_guid; bool trainableInputs[MAX_NUM_INPUTS]; DataType input_type[MAX_NUM_INPUTS]; DataType weight_type[MAX_NUM_WEIGHTS]; diff --git a/include/flexflow/operator.h b/include/flexflow/operator.h index 1b2fc7bbfc..fd21436681 100644 --- a/include/flexflow/operator.h +++ b/include/flexflow/operator.h @@ -1,6 +1,7 @@ #ifndef _OPERATOR_H #define _OPERATOR_H +#include "flexflow/accessor.h" #include "flexflow/batch_config.h" #include "flexflow/fftype.h" #include "flexflow/machine_view.h" @@ -183,6 +184,7 @@ class Op { const ParallelTensor input4 = NULL); Op(int guid, bool profiling, + bool inference_debugging, OperatorType otype, DataType dtype, char const *name, @@ -225,6 +227,13 @@ class Op { assert(false); }; virtual void print_layer(FFModel const &model) = 0; + static void save_inference_tensors_to_file( + OpMeta *m, + int shard_id, + BatchConfig const *bc, + std::vector input_tensors, + std::vector weight_tensors, + std::vector output_tensors); virtual bool measure_operator_cost(Simulator *sim, MachineView const &mv, CostMetrics &cost_metrics) const = 0; @@ -316,6 +325,7 @@ class Op { std::map inference_meta; int numInputs, numWeights, numOutputs; bool profiling; + bool inference_debugging; bool add_bias_only_once; #ifdef FF_USE_NCCL ncclUniqueId ncclId; diff --git a/include/flexflow/ops/add_bias_residual_layer_norm.h b/include/flexflow/ops/add_bias_residual_layer_norm.h index 523f4d3b7c..bb470376c3 100644 --- a/include/flexflow/ops/add_bias_residual_layer_norm.h +++ 
b/include/flexflow/ops/add_bias_residual_layer_norm.h @@ -106,7 +106,6 @@ class AddBiasResidualLayerNormMeta : public OpMeta { int64_t effective_batch_size, effective_num_elements; float eps; void *mean_ptr, *rstd_ptr, *ds_ptr, *db_ptr, *scale_ptr, *bias_ptr; - char op_name[MAX_OPNAME]; Realm::RegionInstance reserveInst; }; diff --git a/include/flexflow/ops/element_unary.h b/include/flexflow/ops/element_unary.h index f82db5f910..ddef59549c 100644 --- a/include/flexflow/ops/element_unary.h +++ b/include/flexflow/ops/element_unary.h @@ -26,7 +26,6 @@ class ElementUnaryMeta : public OpMeta { DataType data_type; bool inplace; float scalar; - char op_name[MAX_OPNAME]; }; class ElementUnary : public Op { diff --git a/include/flexflow/ops/inc_multihead_self_attention.h b/include/flexflow/ops/inc_multihead_self_attention.h index 8290998f02..51a3b9fbe1 100644 --- a/include/flexflow/ops/inc_multihead_self_attention.h +++ b/include/flexflow/ops/inc_multihead_self_attention.h @@ -185,9 +185,6 @@ class IncMultiHeadSelfAttentionMeta : public OpMeta { bool *qk_prod_scaling; bool *position_bias; float scaling_factor; -#ifdef INFERENCE_TESTS - float *kcache, *vcache; -#endif void *weight_ptr, *bias_ptr; // for weight offload void *devQKVProjArray, *keyCache, *valueCache; void *qk_prods, *qk_prods_softmax; diff --git a/include/flexflow/ops/kernels/concat_kernels.h b/include/flexflow/ops/kernels/concat_kernels.h index 755e1800da..4da6aaf5e2 100644 --- a/include/flexflow/ops/kernels/concat_kernels.h +++ b/include/flexflow/ops/kernels/concat_kernels.h @@ -12,7 +12,6 @@ class ConcatMeta : public OpMeta { public: ConcatMeta(FFHandler handle) : OpMeta(handle){}; int legion_axis; - char op_name[MAX_OPNAME]; }; namespace Kernels { diff --git a/include/flexflow/ops/kernels/conv_2d_kernels.h b/include/flexflow/ops/kernels/conv_2d_kernels.h index a848d83d60..7b2a0fe135 100644 --- a/include/flexflow/ops/kernels/conv_2d_kernels.h +++ b/include/flexflow/ops/kernels/conv_2d_kernels.h @@ -28,7 +28,6 @@ class Conv2DMeta : public OpMeta { miopenConvBwdDataAlgorithm_t bwdDataAlgo; #endif bool relu, use_bias; - char op_name[MAX_OPNAME]; }; namespace Kernels { diff --git a/include/flexflow/ops/kernels/element_binary_kernels.h b/include/flexflow/ops/kernels/element_binary_kernels.h index b0c596301b..5a375fb661 100644 --- a/include/flexflow/ops/kernels/element_binary_kernels.h +++ b/include/flexflow/ops/kernels/element_binary_kernels.h @@ -23,7 +23,6 @@ class ElementBinaryMeta : public OpMeta { OperatorType op_type; bool inplace_a, has_same_operands; bool broadcast_input1, broadcast_input2; - char op_name[MAX_OPNAME]; }; namespace Kernels { diff --git a/include/flexflow/ops/kernels/linear_kernels.h b/include/flexflow/ops/kernels/linear_kernels.h index bbebe3c79b..a5fdc7c602 100644 --- a/include/flexflow/ops/kernels/linear_kernels.h +++ b/include/flexflow/ops/kernels/linear_kernels.h @@ -34,7 +34,6 @@ class LinearMeta : public OpMeta { RegularizerMode kernel_reg_type; float kernel_reg_lambda; bool use_bias, add_bias_only_once; - char op_name[MAX_OPNAME]; Realm::RegionInstance reserveInst; }; diff --git a/include/flexflow/ops/kernels/pool_2d_kernels.h b/include/flexflow/ops/kernels/pool_2d_kernels.h index ac86cb50c3..7f73a8295d 100644 --- a/include/flexflow/ops/kernels/pool_2d_kernels.h +++ b/include/flexflow/ops/kernels/pool_2d_kernels.h @@ -14,7 +14,6 @@ class Pool2DMeta : public OpMeta { ffActivationDescriptor_t actiDesc; ffPoolingDescriptor_t poolDesc; bool relu; - char op_name[MAX_OPNAME]; }; namespace Kernels { diff --git 
a/include/flexflow/ops/kernels/residual_rms_norm_kernels.h b/include/flexflow/ops/kernels/residual_rms_norm_kernels.h index aa454711ec..0eef4ca72b 100644 --- a/include/flexflow/ops/kernels/residual_rms_norm_kernels.h +++ b/include/flexflow/ops/kernels/residual_rms_norm_kernels.h @@ -37,7 +37,6 @@ class ResidualRMSNormMeta : public OpMeta { int in_dim; int batch_size; int num_elements; - char op_name[MAX_OPNAME]; Realm::RegionInstance reserveInst; }; diff --git a/include/flexflow/ops/kernels/rms_norm_kernels.h b/include/flexflow/ops/kernels/rms_norm_kernels.h index 2063777ef1..35c5aa69fa 100644 --- a/include/flexflow/ops/kernels/rms_norm_kernels.h +++ b/include/flexflow/ops/kernels/rms_norm_kernels.h @@ -37,7 +37,6 @@ class RMSNormMeta : public OpMeta { int in_dim; int batch_size; int num_elements; - char op_name[MAX_OPNAME]; Realm::RegionInstance reserveInst; }; diff --git a/include/flexflow/ops/kernels/softmax_kernels.h b/include/flexflow/ops/kernels/softmax_kernels.h index 987a546459..8cfaf3c586 100644 --- a/include/flexflow/ops/kernels/softmax_kernels.h +++ b/include/flexflow/ops/kernels/softmax_kernels.h @@ -21,8 +21,8 @@ class SoftmaxMeta : public OpMeta { miopenTensorDescriptor_t outputTensor; #endif bool profiling; + bool inference_debugging; int dim; - char op_name[MAX_OPNAME]; DataType input_type, output_type; }; diff --git a/include/flexflow/ops/layer_norm.h b/include/flexflow/ops/layer_norm.h index c65370e0fd..9e48d81190 100644 --- a/include/flexflow/ops/layer_norm.h +++ b/include/flexflow/ops/layer_norm.h @@ -123,7 +123,6 @@ class LayerNormMeta : public OpMeta { int64_t effective_batch_size, effective_num_elements; float eps; void *mean_ptr, *rstd_ptr, *ds_ptr, *db_ptr, *scale_ptr, *bias_ptr; - char op_name[MAX_OPNAME]; Realm::RegionInstance reserveInst; }; diff --git a/include/flexflow/ops/linear.h b/include/flexflow/ops/linear.h index 025674c7ba..a32df80537 100644 --- a/include/flexflow/ops/linear.h +++ b/include/flexflow/ops/linear.h @@ -96,6 +96,7 @@ class Linear : public Op { private: Linear(int guid, bool profiling, + bool inference_debugging, const ParallelTensor input, int out_dim, ActiMode activation, diff --git a/include/flexflow/ops/residual_layer_norm.h b/include/flexflow/ops/residual_layer_norm.h index 39f149554e..0e9be82125 100644 --- a/include/flexflow/ops/residual_layer_norm.h +++ b/include/flexflow/ops/residual_layer_norm.h @@ -106,7 +106,6 @@ class ResidualLayerNormMeta : public OpMeta { int64_t effective_batch_size, effective_num_elements; float eps; void *mean_ptr, *rstd_ptr, *ds_ptr, *db_ptr, *scale_ptr, *bias_ptr; - char op_name[MAX_OPNAME]; Realm::RegionInstance reserveInst; }; diff --git a/include/flexflow/ops/residual_rms_norm.h b/include/flexflow/ops/residual_rms_norm.h index 783173c5eb..0d92a236e8 100644 --- a/include/flexflow/ops/residual_rms_norm.h +++ b/include/flexflow/ops/residual_rms_norm.h @@ -80,7 +80,6 @@ class ResidualRMSNorm : public Op { public: float eps; - char op_name[MAX_OPNAME]; int effective_batch_size; int dim, data_dim; }; diff --git a/include/flexflow/ops/rms_norm.h b/include/flexflow/ops/rms_norm.h index a3074de015..1dc940ebd3 100644 --- a/include/flexflow/ops/rms_norm.h +++ b/include/flexflow/ops/rms_norm.h @@ -79,7 +79,6 @@ class RMSNorm : public Op { public: float eps; - char op_name[MAX_OPNAME]; int effective_batch_size; int dim, data_dim; }; diff --git a/include/flexflow/ops/sigmoid_silu_multi.h b/include/flexflow/ops/sigmoid_silu_multi.h index 6a69288607..604438260a 100644 --- a/include/flexflow/ops/sigmoid_silu_multi.h +++ 
b/include/flexflow/ops/sigmoid_silu_multi.h @@ -79,7 +79,6 @@ class SigmoidSiluMultiMeta : public OpMeta { ~SigmoidSiluMultiMeta(void); public: - char op_name[MAX_OPNAME]; Realm::RegionInstance reserveInst; }; diff --git a/include/flexflow/utils/hip_helper.h b/include/flexflow/utils/hip_helper.h index be6f4a713d..5d3c831d4f 100644 --- a/include/flexflow/utils/hip_helper.h +++ b/include/flexflow/utils/hip_helper.h @@ -137,6 +137,9 @@ __host__ void updateGAS(float *para_ptr, template void print_tensor(T const *ptr, size_t num_elements, char const *prefix); +template +void save_tensor(T const *ptr, size_t num_elements, char const *file_name); + template T *download_tensor(T const *ptr, size_t num_elements); diff --git a/inference/python/incr_decoding.py b/inference/python/incr_decoding.py index d8a494b4d5..3621ee83a3 100644 --- a/inference/python/incr_decoding.py +++ b/inference/python/incr_decoding.py @@ -55,6 +55,7 @@ def get_configs(): "use_4bit_quantization": False, "use_8bit_quantization": False, "profiling": False, + "inference_debugging": False, "fusion": True, } llm_configs = { diff --git a/inference/python/spec_infer.py b/inference/python/spec_infer.py index c9e87bd29f..3d0f1a1c0e 100644 --- a/inference/python/spec_infer.py +++ b/inference/python/spec_infer.py @@ -55,6 +55,7 @@ def get_configs(): "use_4bit_quantization": False, "use_8bit_quantization": False, "profiling": False, + "inference_debugging": False, "fusion": True, } llm_configs = { diff --git a/python/flexflow/core/__init__.py b/python/flexflow/core/__init__.py index ace6030a1b..d7b1a595d2 100644 --- a/python/flexflow/core/__init__.py +++ b/python/flexflow/core/__init__.py @@ -41,6 +41,7 @@ "num_cpus": "-ll:cpu", "legion_utility_processors": "-ll:util", "profiling": "--profiling", + "inference_debugging": "--inference-debugging", "fusion": "--fusion", "disable_control_replication": "--disable-control-replication", # Training args diff --git a/python/flexflow/serve/__init__.py b/python/flexflow/serve/__init__.py index 9b282ae5f4..cf467280bd 100644 --- a/python/flexflow/serve/__init__.py +++ b/python/flexflow/serve/__init__.py @@ -45,6 +45,7 @@ def init( use_4bit_quantization: Optional[bool] = None, use_8bit_quantization: Optional[bool] = None, profiling: Optional[bool] = None, + inference_debugging: Optional[bool] = None, fusion: Optional[bool] = None, ): """ @@ -71,6 +72,7 @@ def init( - use_4bit_quantization: whether to use 4-bit quantization, defaults to False - use_8bit_quantization: whether to use 8-bit quantization, defaults to False - profiling: whether to enable the FlexFlow profiling mode, defaults to False + - inference_debugging: whether to run inference in debugging mode, saving all inputs/outputs/weights to file, defaults to False - fusion: whether to enable the FlexFlow operator fusion optimization, defaults to True The configurations are passed down to the FlexFlow runtime (implemented in C++) via command line arguments. 
@@ -104,6 +106,8 @@ def init( :type use_8bit_quantization: Optional[bool], optional :param profiling: whether to enable the FlexFlow profiling mode, defaults to False :type profiling: Optional[bool], optional + :param inference_debugging: whether to run inference in debugging mode, saving all inputs/outputs/weights to file, defaults to False + :type inference_debugging: Optional[bool], optional :param fusion: whether to enable the FlexFlow operator fusion optimization, defaults to True :type fusion: Optional[bool], optional @@ -128,6 +132,7 @@ def init( use_4bit_quantization is not None, use_8bit_quantization is not None, profiling is not None, + inference_debugging is not None, fusion is not None, ] ): @@ -152,6 +157,7 @@ def init( "use_4bit_quantization": use_4bit_quantization, "use_8bit_quantization": use_8bit_quantization, "profiling": profiling, + "inference_debugging": inference_debugging, "fusion": fusion, } @@ -195,6 +201,8 @@ def init( configs_dict["use_8bit_quantization"] = False if configs_dict.get("profiling", None) is None: configs_dict["profiling"] = False + if configs_dict.get("inference_debugging", None) is None: + configs_dict["inference_debugging"] = False if configs_dict.get("fusion", None) is None: configs_dict["fusion"] = True diff --git a/src/c/flexflow_c.cc b/src/c/flexflow_c.cc index 5bb5249f5d..80202f6f99 100644 --- a/src/c/flexflow_c.cc +++ b/src/c/flexflow_c.cc @@ -1592,7 +1592,8 @@ flexflow_generation_result_t std::string const text_str(input_text); prompts.push_back(input_text); GenerationResult result = handle->generate(prompts, max_seq_length); - DEBUG_PRINT("[Model] generate %p %s %i", handle, text_str, max_seq_length); + DEBUG_PRINT( + "[Model] generate %p %s %i", handle, text_str.c_str(), max_seq_length); assert(result.output_tokens.size() <= max_seq_length); output_length_and_tokens[0] = result.output_tokens.size(); std::copy(result.output_tokens.begin(), diff --git a/src/ops/add_bias_residual_layer_norm.cc b/src/ops/add_bias_residual_layer_norm.cc index 159c82b346..42fbb3016a 100644 --- a/src/ops/add_bias_residual_layer_norm.cc +++ b/src/ops/add_bias_residual_layer_norm.cc @@ -508,6 +508,8 @@ OpMeta *AddBiasResidualLayerNorm::init_task( } meta->output_type[0] = ln->outputs[0]->data_type; meta->output_type[1] = ln->outputs[1]->data_type; + std::strcpy(meta->op_name, ln->name); + meta->layer_guid = ln->layer_guid; return meta; } @@ -620,7 +622,7 @@ void AddBiasResidualLayerNorm::inference_task( return; } - AddBiasResidualLayerNormMeta const *m = + AddBiasResidualLayerNormMeta *m = *((AddBiasResidualLayerNormMeta **)task->local_args); assert(regions.size() == @@ -669,40 +671,6 @@ void AddBiasResidualLayerNorm::inference_task( assert(in_domain.get_volume() == m->effective_num_elements * m->effective_batch_size); - // std::cout << std::endl << "INFERENCE task tensor dims:" << std::endl; - // std::cout << "input: "; - // for (int i=0; ielementwise_affine) { gamma = helperGetGenericTensorAccessorRO(m->weight_type[1], regions[5], @@ -749,6 +717,26 @@ void AddBiasResidualLayerNorm::inference_task( attn_bias, gamma, beta); + + if (m->inference_debugging) { + assert(task->index_point.get_dim() == 1); + int shard_id = task->index_point.point_data[0]; + std::vector weights_accessors; + weights_accessors.push_back(attn_bias); + if (m->elementwise_affine) { + weights_accessors.push_back(gamma); + if (m->use_bias) { + weights_accessors.push_back(beta); + } + } + AddBiasResidualLayerNorm::save_inference_tensors_to_file( + m, + shard_id, + bc, + {input, residual}, + 
weights_accessors, + {added_output, output}); + } } bool AddBiasResidualLayerNorm::measure_operator_cost( @@ -759,6 +747,7 @@ bool AddBiasResidualLayerNorm::measure_operator_cost( void AddBiasResidualLayerNorm::serialize(Legion::Serializer &sez) const { sez.serialize(this->layer_guid.id); sez.serialize(this->layer_guid.transformer_layer_id); + sez.serialize(this->layer_guid.model_id); sez.serialize(this->axes.size()); for (size_t i = 0; i < this->axes.size(); i++) { sez.serialize(this->axes[i]); @@ -780,10 +769,11 @@ Node AddBiasResidualLayerNorm::deserialize(FFModel &ff, bool elementwise_affine; bool use_bias; float eps; - size_t id, transformer_layer_id; + size_t id, transformer_layer_id, deserialized_model_id; dez.deserialize(id); dez.deserialize(transformer_layer_id); - LayerID layer_guid(id, transformer_layer_id); + dez.deserialize(deserialized_model_id); + LayerID layer_guid(id, transformer_layer_id, deserialized_model_id); dez.deserialize(num_axes); for (size_t i = 0; i < num_axes; i++) { int axis_idx; @@ -812,6 +802,7 @@ size_t hash::operator()( size_t key = 0; hash_combine(key, params.layer_guid.id); hash_combine(key, params.layer_guid.transformer_layer_id); + hash_combine(key, params.layer_guid.model_id); hash_combine(key, params.axes.size()); for (int n : params.axes) { hash_combine(key, n); diff --git a/src/ops/add_bias_residual_layer_norm.cpp b/src/ops/add_bias_residual_layer_norm.cpp index 3570ae42dc..1add43ecd9 100644 --- a/src/ops/add_bias_residual_layer_norm.cpp +++ b/src/ops/add_bias_residual_layer_norm.cpp @@ -34,6 +34,7 @@ AddBiasResidualLayerNormMeta::AddBiasResidualLayerNormMeta( effective_batch_size = ln->effective_batch_size; effective_num_elements = ln->effective_num_elements; profiling = ln->profiling; + inference_debugging = ln->inference_debugging; eps = ln->eps; DataType data_type = ln->data_type; size_t totalSize = effective_batch_size * data_type_size(data_type) * 3; diff --git a/src/ops/add_bias_residual_layer_norm.cu b/src/ops/add_bias_residual_layer_norm.cu index 626e56d64f..ceb1a6514e 100644 --- a/src/ops/add_bias_residual_layer_norm.cu +++ b/src/ops/add_bias_residual_layer_norm.cu @@ -33,6 +33,7 @@ AddBiasResidualLayerNormMeta::AddBiasResidualLayerNormMeta( effective_batch_size = ln->effective_batch_size; effective_num_elements = ln->effective_num_elements; profiling = ln->profiling; + inference_debugging = ln->inference_debugging; eps = ln->eps; DataType data_type = ln->data_type; size_t totalSize = effective_batch_size * data_type_size(data_type) * 3; diff --git a/src/ops/aggregate.cc b/src/ops/aggregate.cc index c7217bb700..67810d3f5b 100644 --- a/src/ops/aggregate.cc +++ b/src/ops/aggregate.cc @@ -235,6 +235,9 @@ OpMeta *Aggregate::init_task(Task const *task, FFHandler handle = *((FFHandler *)task->local_args); AggregateMeta *m = new AggregateMeta(handle, agg->n); m->profiling = agg->profiling; + m->inference_debugging = agg->inference_debugging; + std::strcpy(m->op_name, agg->name); + m->layer_guid = agg->layer_guid; return m; } diff --git a/src/ops/aggregate_spec.cc b/src/ops/aggregate_spec.cc index 5190983148..19b2edc14a 100644 --- a/src/ops/aggregate_spec.cc +++ b/src/ops/aggregate_spec.cc @@ -209,6 +209,9 @@ OpMeta *AggregateSpec::init_task(Task const *task, FFHandler handle = *((FFHandler *)task->local_args); AggregateSpecMeta *m = new AggregateSpecMeta(handle, agg->n); m->profiling = agg->profiling; + m->inference_debugging = agg->inference_debugging; + std::strcpy(m->op_name, agg->name); + m->layer_guid = agg->layer_guid; return m; } diff 
--git a/src/ops/arg_topk.cc b/src/ops/arg_topk.cc index b877a9f96d..a06b89de07 100644 --- a/src/ops/arg_topk.cc +++ b/src/ops/arg_topk.cc @@ -241,7 +241,10 @@ OpMeta *ArgTopK::init_task(Task const *task, FFHandler handle = *((FFHandler *)task->local_args); ArgTopKMeta *m = new ArgTopKMeta(handle, topk); m->profiling = topk->profiling; + m->inference_debugging = topk->inference_debugging; m->sorted = topk->sorted; + std::strcpy(m->op_name, topk->name); + m->layer_guid = topk->layer_guid; return m; } @@ -308,7 +311,7 @@ InferenceResult InferenceResult ir; return ir; } - ArgTopKMeta const *m = *((ArgTopKMeta **)task->local_args); + ArgTopKMeta *m = *((ArgTopKMeta **)task->local_args); GenericTensorAccessorR input = helperGetGenericTensorAccessorRO( m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); @@ -318,6 +321,13 @@ InferenceResult int batch_size = bc->num_active_tokens(); ArgTopK::forward_kernel_wrapper(m, input, indices, batch_size); + if (m->inference_debugging) { + assert(task->index_point.get_dim() == 1); + int shard_id = task->index_point.point_data[0]; + ArgTopK::save_inference_tensors_to_file( + m, shard_id, bc, {input}, {}, {indices}); + } + InferenceResult ir; download_tensor( indices.get_int32_ptr(), ir.token_ids, batch_size); @@ -332,6 +342,7 @@ void ArgTopK::backward(FFModel const &ff) { void ArgTopK::serialize(Legion::Serializer &sez) const { sez.serialize(this->layer_guid.id); sez.serialize(this->layer_guid.transformer_layer_id); + sez.serialize(this->layer_guid.model_id); sez.serialize(this->k); sez.serialize(this->sorted); } @@ -341,10 +352,11 @@ Node ArgTopK::deserialize(FFModel &ff, ParallelTensor inputs[], int num_inputs) { assert(num_inputs == 1); - size_t id, transformer_layer_id; + size_t id, transformer_layer_id, deserialized_model_id; dez.deserialize(id); dez.deserialize(transformer_layer_id); - LayerID layer_guid(id, transformer_layer_id); + dez.deserialize(deserialized_model_id); + LayerID layer_guid(id, transformer_layer_id, deserialized_model_id); int k; bool sorted; dez.deserialize(k); diff --git a/src/ops/argmax.cc b/src/ops/argmax.cc index 7863931c82..f336c843e8 100644 --- a/src/ops/argmax.cc +++ b/src/ops/argmax.cc @@ -245,7 +245,10 @@ OpMeta *ArgMax::init_task(Task const *task, length * batch_size, gpu_mem_allocator); m->profiling = s->profiling; + m->inference_debugging = s->inference_debugging; m->beam_search = s->beam_search; + std::strcpy(m->op_name, s->name); + m->layer_guid = s->layer_guid; return m; } @@ -339,7 +342,7 @@ BeamInferenceResult BeamInferenceResult ir; return ir; } - ArgMaxMeta const *m = *((ArgMaxMeta **)task->local_args); + ArgMaxMeta *m = *((ArgMaxMeta **)task->local_args); GenericTensorAccessorW input = helperGetGenericTensorAccessorRW( m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); @@ -355,6 +358,14 @@ BeamInferenceResult indices.get_int32_ptr(), ir.token_ids, batch_size); download_tensor(m->probs, ir.probs, batch_size); download_tensor(parent.get_int32_ptr(), ir.parent_id, batch_size); + + if (m->inference_debugging) { + assert(task->index_point.get_dim() == 1); + int shard_id = task->index_point.point_data[0]; + ArgMax::save_inference_tensors_to_file( + m, shard_id, bc, {}, {}, {input, indices, parent}); + } + return ir; } @@ -365,7 +376,7 @@ InferenceResult Runtime *runtime) { assert(regions.size() == 2); assert(task->regions.size() == 2); - ArgMaxMeta const *m = *((ArgMaxMeta **)task->local_args); + ArgMaxMeta *m = *((ArgMaxMeta **)task->local_args); BatchConfig const *bc = 
BatchConfig::from_future(task->futures[0]); if (bc->num_tokens == 0) { // Directly return for empty batch config @@ -381,6 +392,12 @@ InferenceResult int batch_size = bc->num_active_tokens(); ArgMax::forward_kernel_wrapper(m, input, indices, parent, batch_size); InferenceResult ir; + if (m->inference_debugging) { + assert(task->index_point.get_dim() == 1); + int shard_id = task->index_point.point_data[0]; + ArgMax::save_inference_tensors_to_file( + m, shard_id, bc, {}, {}, {input, indices}); + } download_tensor( indices.get_int32_ptr(), ir.token_ids, batch_size); return ir; diff --git a/src/ops/attention.cc b/src/ops/attention.cc index 027ea18634..1f71be07a8 100644 --- a/src/ops/attention.cc +++ b/src/ops/attention.cc @@ -521,6 +521,9 @@ OpMeta * MultiHeadAttentionMeta *m = new MultiHeadAttentionMeta(handle, attn, gpu_mem, num_samples, num_heads); m->profiling = attn->profiling; + m->inference_debugging = attn->inference_debugging; + std::strcpy(m->op_name, attn->name); + m->layer_guid = attn->layer_guid; assert(acc_weight.rect.volume() * sizeof(float) == m->weightSize); return m; } diff --git a/src/ops/batch_matmul.cc b/src/ops/batch_matmul.cc index 977c5443b9..f4b06877e5 100644 --- a/src/ops/batch_matmul.cc +++ b/src/ops/batch_matmul.cc @@ -274,8 +274,11 @@ OpMeta *BatchMatmul::init_task(Task const *task, FFHandler handle = *((FFHandler const *)task->local_args); BatchMatmulMeta *m = new BatchMatmulMeta(handle); m->profiling = bmm->profiling; + m->inference_debugging = bmm->inference_debugging; m->a_seq_length_dim = bmm->a_seq_length_dim; m->b_seq_length_dim = bmm->b_seq_length_dim; + std::strcpy(m->op_name, bmm->name); + m->layer_guid = bmm->layer_guid; return m; } diff --git a/src/ops/batch_norm.cpp b/src/ops/batch_norm.cpp index 34a7fcbe72..106e5ebad2 100644 --- a/src/ops/batch_norm.cpp +++ b/src/ops/batch_norm.cpp @@ -293,6 +293,7 @@ BatchNormMeta::BatchNormMeta(FFHandler handler, checkCUDNN(miopenCreateTensorDescriptor(&outputTensor)); relu = bn->relu; profiling = bn->profiling; + inference_debugging = bn->inference_debugging; mode = miopenBNSpatial; // #if HIPDNN_VERSION >= 7000 // mode = HIPDNN_BATCHNORM_SPATIAL_PERSISTENT; diff --git a/src/ops/batch_norm.cu b/src/ops/batch_norm.cu index c17244dce0..b77e9d489f 100644 --- a/src/ops/batch_norm.cu +++ b/src/ops/batch_norm.cu @@ -279,6 +279,7 @@ BatchNormMeta::BatchNormMeta(FFHandler handler, checkCUDNN(cudnnCreateTensorDescriptor(&outputTensor)); relu = bn->relu; profiling = bn->profiling; + inference_debugging = bn->inference_debugging; mode = CUDNN_BATCHNORM_SPATIAL; #if CUDNN_VERSION >= 7000 mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; diff --git a/src/ops/beam_topk.cc b/src/ops/beam_topk.cc index 93a6de5a8f..2883428254 100644 --- a/src/ops/beam_topk.cc +++ b/src/ops/beam_topk.cc @@ -278,6 +278,9 @@ OpMeta *BeamTopK::init_task(Task const *task, MemoryAllocator gpu_mem_allocator(gpu_mem); BeamTopKMeta *m = new BeamTopKMeta(handle, topk, gpu_mem_allocator); m->profiling = topk->profiling; + m->inference_debugging = topk->inference_debugging; + std::strcpy(m->op_name, topk->name); + m->layer_guid = topk->layer_guid; m->sorted = topk->sorted; m->max_beam_width = topk->max_beam_width; m->input_type[0] = topk->inputs[0]->data_type; @@ -346,60 +349,36 @@ BeamInferenceResult assert(regions.size() == 4); assert(task->regions.size() == 4); - // BeamSearchBatchConfig const *bc = (BeamSearchBatchConfig *)task->args; + BeamTopKMeta *m = *((BeamTopKMeta **)task->local_args); BeamSearchBatchConfig const &bc = Future(task->futures[0]).get_result(); 
-  // std::cout << "beam search topk inference: "
-  //           << "\n";
+
   if (bc.num_tokens == 0) {
     BeamInferenceResult ir;
     return ir;
   }
-  BeamTopKMeta const *m = *((BeamTopKMeta **)task->local_args);
-  Domain in1_domain = runtime->get_index_space_domain(
-      ctx, task->regions[0].region.get_index_space());
-  // Domain out1_domain = runtime->get_index_space_domain(
-  //     ctx, task->regions[1].region.get_index_space());
-  Domain out2_domain = runtime->get_index_space_domain(
-      ctx, task->regions[1].region.get_index_space());
-  int numdims = in1_domain.get_dim();
-
-  // float const *in_ptr = helperGetTensorPointerRO(
-  //     regions[0], task->regions[0], FID_DATA, ctx, runtime);
   GenericTensorAccessorR input = helperGetGenericTensorAccessorRO(
       m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime);
-  // float *value_ptr = helperGetTensorPointerWO(
-  //     regions[1], task->regions[1], FID_DATA, ctx, runtime);
-  int *index_ptr = helperGetTensorPointerWO(
-      regions[1], task->regions[1], FID_DATA, ctx, runtime);
+  GenericTensorAccessorW index = helperGetGenericTensorAccessorWO(
+      DT_INT32, regions[1], task->regions[1], FID_DATA, ctx, runtime);
+  GenericTensorAccessorW value = helperGetGenericTensorAccessorWO(
+      DT_FLOAT, regions[2], task->regions[2], FID_DATA, ctx, runtime);
+  GenericTensorAccessorW parent = helperGetGenericTensorAccessorWO(
+      DT_INT32, regions[3], task->regions[3], FID_DATA, ctx, runtime);
+
+  Domain input_domain = runtime->get_index_space_domain(
+      ctx, task->regions[0].region.get_index_space());
-  // );
-  float *value_ptr = helperGetTensorPointerWO(
-      regions[2], task->regions[2], FID_DATA, ctx, runtime);
+  int *index_ptr = index.get_int32_ptr();
+  float *value_ptr = value.get_float_ptr();
+  int *parent_ptr = parent.get_int32_ptr();
-  int *parent_ptr = helperGetTensorPointerWO(
-      regions[3], task->regions[3], FID_DATA, ctx, runtime);
   // embedding size: eg.
4096 - int length = in1_domain.hi()[0] - in1_domain.lo()[0] + 1; - - // int k = out2_domain.hi()[0] - out2_domain.lo()[0] + 1; - + int length = input_domain.hi()[0] - input_domain.lo()[0] + 1; // total token nums - // size_t tokens_per_request = in1_domain.hi()[1] - in1_domain.lo()[1] + 1; - // size_t batch_size = in1_domain.get_volume() / length; size_t batch_size = bc.num_active_tokens(); - // std::vector beam_width; - // std::unordered_map sub_requests = bc->sub_requests; - // for (int i = 0; i < bc->MAX_NUM_REQUESTS; i++) { - // if (bc->request_completed[i]) { - // continue; - // } - // // add beam width for each main request - // beam_width.push_back(sub_requests[i]); - // std::cout << "sub req num: " <(index_ptr, ir.token_ids, batch_size * m->max_beam_width); download_tensor(value_ptr, ir.probs, batch_size * m->max_beam_width); - // if(m->output_type[0] == DT_FLOAT){ - // download_tensor(value.get_float_ptr(), ir.probs, batch_size * - // m->max_beam_width); - // }else if(m->output_type[0] == DT_HALF){ - // download_tensor(value.get_half_ptr(), ir.probs, batch_size * - // m->max_beam_width); - // } download_tensor( parent_ptr, ir.parent_id, batch_size * m->max_beam_width); + + if (m->inference_debugging) { + assert(task->index_point.get_dim() == 1); + int shard_id = task->index_point.point_data[0]; + BeamTopK::save_inference_tensors_to_file( + m, shard_id, &bc, {input}, {}, {index, value, parent}); + } + return ir; } @@ -435,6 +415,7 @@ void BeamTopK::backward(FFModel const &ff) { void BeamTopK::serialize(Legion::Serializer &sez) const { sez.serialize(this->layer_guid.id); sez.serialize(this->layer_guid.transformer_layer_id); + sez.serialize(this->layer_guid.model_id); sez.serialize(this->sorted); sez.serialize(this->max_beam_width); } @@ -445,11 +426,12 @@ Node BeamTopK::deserialize(FFModel &ff, int num_inputs) { assert(num_inputs == 1); bool sorted; - size_t id, transformer_layer_id; + size_t id, transformer_layer_id, deserialized_model_id; int max_beam_width; dez.deserialize(id); dez.deserialize(transformer_layer_id); - LayerID layer_guid(id, transformer_layer_id); + dez.deserialize(deserialized_model_id); + LayerID layer_guid(id, transformer_layer_id, deserialized_model_id); dez.deserialize(sorted); dez.deserialize(max_beam_width); BeamTopKParams params; diff --git a/src/ops/cache.cc b/src/ops/cache.cc index 339b2cab55..691e45b559 100644 --- a/src/ops/cache.cc +++ b/src/ops/cache.cc @@ -168,6 +168,9 @@ OpMeta *Cache::init_task(Task const *task, CacheMeta *m = new CacheMeta(handle); m->cache_score = 0.0f; m->profiling = c->profiling; + m->inference_debugging = c->inference_debugging; + std::strcpy(m->op_name, c->name); + m->layer_guid = c->layer_guid; return m; } diff --git a/src/ops/cast.cc b/src/ops/cast.cc index d98a54fe62..2a845cb303 100644 --- a/src/ops/cast.cc +++ b/src/ops/cast.cc @@ -193,6 +193,8 @@ OpMeta *Cast::init_task(Task const *task, CastMeta *m = new CastMeta(handler); m->input_data_type = cast->inputs[0]->data_type; m->output_data_type = cast->outputs[0]->data_type; + std::strcpy(m->op_name, cast->name); + m->layer_guid = cast->layer_guid; return m; } diff --git a/src/ops/concat.cc b/src/ops/concat.cc index 8014d1e145..80935e387b 100644 --- a/src/ops/concat.cc +++ b/src/ops/concat.cc @@ -201,7 +201,9 @@ OpMeta *Concat::init_task(Task const *task, // Note that our internal axis index ordering is opposite to other frameworks init_meta(m, cc->legion_axis); m->profiling = cc->profiling; + m->inference_debugging = cc->inference_debugging; std::strcpy(m->op_name, 
cc->name); + m->layer_guid = cc->layer_guid; return m; } diff --git a/src/ops/conv_2d.cc b/src/ops/conv_2d.cc index ce7b6ebc01..7d8fd32570 100644 --- a/src/ops/conv_2d.cc +++ b/src/ops/conv_2d.cc @@ -592,8 +592,10 @@ OpMeta *Conv2D::init_task(Task const *task, m->relu = conv->activation == AC_MODE_RELU; m->use_bias = conv->use_bias; m->profiling = conv->profiling; + m->inference_debugging = conv->inference_debugging; m->trainableInputs[0] = conv->trainableInputs[0]; std::strcpy(m->op_name, conv->name); + m->layer_guid = conv->layer_guid; int input_w = acc_input.rect.hi[0] - acc_input.rect.lo[0] + 1; int input_h = acc_input.rect.hi[1] - acc_input.rect.lo[1] + 1; @@ -1013,6 +1015,7 @@ bool Conv2D::estimate_sync_cost(Simulator *sim, void Conv2D::serialize(Legion::Serializer &sez) const { sez.serialize(this->layer_guid.id); sez.serialize(this->layer_guid.transformer_layer_id); + sez.serialize(this->layer_guid.model_id); sez.serialize(this->out_channels); sez.serialize(this->kernel_h); sez.serialize(this->kernel_w); @@ -1037,10 +1040,11 @@ Node Conv2D::deserialize(FFModel &ff, padding_w, groups; bool use_bias; ActiMode activation; - size_t id, transformer_layer_id; + size_t id, transformer_layer_id, deserialized_model_id; dez.deserialize(id); dez.deserialize(transformer_layer_id); - LayerID layer_guid(id, transformer_layer_id); + dez.deserialize(deserialized_model_id); + LayerID layer_guid(id, transformer_layer_id, deserialized_model_id); dez.deserialize(out_channels); dez.deserialize(kernel_h); dez.deserialize(kernel_w); diff --git a/src/ops/dropout.cc b/src/ops/dropout.cc index e4021a404e..9b11c9d912 100644 --- a/src/ops/dropout.cc +++ b/src/ops/dropout.cc @@ -170,6 +170,8 @@ OpMeta *Dropout::init_task(Task const *task, .first(); assert(input_domain == output_domain); DropoutMeta *m = new DropoutMeta(handle, dropout, gpu_mem, output_domain); + std::strcpy(m->op_name, dropout->name); + m->layer_guid = dropout->layer_guid; return m; } diff --git a/src/ops/element_binary.cc b/src/ops/element_binary.cc index 21edad11e3..aa31477815 100644 --- a/src/ops/element_binary.cc +++ b/src/ops/element_binary.cc @@ -420,11 +420,13 @@ OpMeta *ElementBinary::init_task(Task const *task, } m->op_type = eb->op_type; m->profiling = eb->profiling; + m->inference_debugging = eb->inference_debugging; m->inplace_a = eb->inplace_a; m->has_same_operands = eb->has_same_operands; m->broadcast_input1 = eb->broadcast_input1; m->broadcast_input2 = eb->broadcast_input2; std::strcpy(m->op_name, eb->name); + m->layer_guid = eb->layer_guid; Domain input1_domain = runtime->get_index_space_domain( ctx, task->regions[0].region.get_index_space()); Domain input2_domain, output_domain; @@ -620,7 +622,7 @@ __host__ void return; } // const ElementBinary* ele = (const ElementBinary*) task->args; - ElementBinaryMeta const *m = *((ElementBinaryMeta **)task->local_args); + ElementBinaryMeta *m = *((ElementBinaryMeta **)task->local_args); GenericTensorAccessorR in1, in2; GenericTensorAccessorW out; Domain in1_domain = runtime->get_index_space_domain( @@ -705,8 +707,14 @@ __host__ void runtime); } } - forward_kernel_wrapper(m, in1, in2, out); + if (m->inference_debugging) { + assert(task->index_point.get_dim() == 1); + int shard_id = task->index_point.point_data[0]; + std::vector weights_accessors; + ElementBinary::save_inference_tensors_to_file( + m, shard_id, bc, {in1, in2}, {}, {out}); + } } /* @@ -1011,6 +1019,7 @@ bool ElementBinary::measure_operator_cost(Simulator *sim, ElementBinaryMeta *m = new ElementBinaryMeta(sim->handler, this); 
m->op_type = op_type; m->profiling = this->profiling; + m->inference_debugging = this->inference_debugging; m->inplace_a = this->inplace_a; m->has_same_operands = this->has_same_operands; m->broadcast_input1 = this->broadcast_input1; @@ -1103,6 +1112,7 @@ bool ElementBinary::measure_operator_cost(Simulator *sim, void ElementBinary::serialize(Legion::Serializer &sez) const { sez.serialize(this->layer_guid.id); sez.serialize(this->layer_guid.transformer_layer_id); + sez.serialize(this->layer_guid.model_id); sez.serialize(this->op_type); sez.serialize(this->inplace_a); } @@ -1115,11 +1125,12 @@ Node ElementBinary::deserialize(FFModel &ff, int num_inputs) { assert(num_inputs == 2); OperatorType op_type; - size_t id, transformer_layer_id; + size_t id, transformer_layer_id, deserialized_model_id; bool inplace_a; dez.deserialize(id); dez.deserialize(transformer_layer_id); - LayerID layer_guid(id, transformer_layer_id); + dez.deserialize(deserialized_model_id); + LayerID layer_guid(id, transformer_layer_id, deserialized_model_id); dez.deserialize(op_type); dez.deserialize(inplace_a); diff --git a/src/ops/element_unary.cc b/src/ops/element_unary.cc index bdb594b0f6..9fb2e6dc1f 100644 --- a/src/ops/element_unary.cc +++ b/src/ops/element_unary.cc @@ -360,9 +360,11 @@ OpMeta *ElementUnary::init_task(Task const *task, // Input and output should have the same data type assert(eu->outputs[0]->data_type == eu->inputs[0]->data_type); m->profiling = eu->profiling; + m->inference_debugging = eu->inference_debugging; m->inplace = eu->inplace; m->scalar = eu->scalar; std::strcpy(m->op_name, eu->name); + m->layer_guid = eu->layer_guid; if (m->inplace) { assert(regions.size() == 1); assert(task->regions.size() == 1); @@ -525,7 +527,7 @@ void ElementUnary::forward_task_with_type( Context ctx, Runtime *runtime) { // const ElementUnary* ele = (const ElementUnary*) task->args; - ElementUnaryMeta const *m = *((ElementUnaryMeta **)task->local_args); + ElementUnaryMeta *m = *((ElementUnaryMeta **)task->local_args); Domain input_domain = runtime->get_index_space_domain( ctx, task->regions[0].region.get_index_space()); const DT *input_ptr = NULL; @@ -550,6 +552,27 @@ void ElementUnary::forward_task_with_type( ElementUnary::forward_kernel_wrapper
<DT>(
       m, input_ptr, output_ptr, input_domain.get_volume());
+
+  if (m->inference_debugging) {
+    assert(task->index_point.get_dim() == 1);
+    int shard_id = task->index_point.point_data[0];
+    std::vector<GenericTensorAccessorR> input_accessors;
+    std::vector<GenericTensorAccessorW> output_accessors;
+    if (m->inplace) {
+      GenericTensorAccessorW output = helperGetGenericTensorAccessorWO(
+          m->data_type, regions[0], task->regions[0], FID_DATA, ctx, runtime);
+      output_accessors.push_back(output);
+    } else {
+      GenericTensorAccessorR input = helperGetGenericTensorAccessorRO(
+          m->data_type, regions[0], task->regions[0], FID_DATA, ctx, runtime);
+      GenericTensorAccessorW output = helperGetGenericTensorAccessorWO(
+          m->data_type, regions[1], task->regions[1], FID_DATA, ctx, runtime);
+      input_accessors.push_back(input);
+      output_accessors.push_back(output);
+    }
+    ElementUnary::save_inference_tensors_to_file(
+        m, shard_id, nullptr, input_accessors, {}, output_accessors);
+  }
 }
 void ElementUnary::backward(FFModel const &ff) {
@@ -699,6 +722,7 @@ void ElementUnary::serialize(Legion::Serializer &sez) const {
   sez.serialize(scalar);
   sez.serialize(this->layer_guid.id);
   sez.serialize(this->layer_guid.transformer_layer_id);
+  sez.serialize(this->layer_guid.model_id);
 }
 bool ElementUnary::measure_operator_cost(Simulator *sim,
@@ -809,10 +833,11 @@ Node ElementUnary::deserialize(FFModel &ff,
   dez.deserialize(op_type);
   dez.deserialize(inplace);
   dez.deserialize(scalar);
-  size_t id, transformer_layer_id;
+  size_t id, transformer_layer_id, deserialized_model_id;
   dez.deserialize(id);
   dez.deserialize(transformer_layer_id);
-  LayerID layer_guid(id, transformer_layer_id);
+  dez.deserialize(deserialized_model_id);
+  LayerID layer_guid(id, transformer_layer_id, deserialized_model_id);
   ElementUnaryParams params;
   params.op_type = op_type;
diff --git a/src/ops/embedding.cc b/src/ops/embedding.cc
index 409dcb398e..007e799fe0 100644
--- a/src/ops/embedding.cc
+++ b/src/ops/embedding.cc
@@ -416,7 +416,10 @@ OpMeta *Embedding::init_task(Task const *task,
   FFHandler handle = *((FFHandler const *)task->local_args);
   EmbeddingMeta *m = new EmbeddingMeta(handle, embed);
   m->profiling = embed->profiling;
+  m->inference_debugging = embed->inference_debugging;
   m->aggr = embed->aggr;
+  std::strcpy(m->op_name, embed->name);
+  m->layer_guid = embed->layer_guid;
   return m;
 }
@@ -514,7 +517,7 @@ void Embedding::forward_task(Task const *task,
                              std::vector<PhysicalRegion> const &regions,
                              Context ctx,
                              Runtime *runtime) {
-  EmbeddingMeta const *m = *((EmbeddingMeta **)task->local_args);
+  EmbeddingMeta *m = *((EmbeddingMeta **)task->local_args);
   assert(regions.size() == 3);
   assert(task->regions.size() == 3);
   // Assert that weight and output must have the same data type
@@ -561,75 +564,13 @@ void Embedding::forward_task(Task const *task,
   }
   forward_kernel_wrapper(
       m, input, output, kernel, in_dim, out_dim, effective_batch_size);
-}
-
-#ifdef DEADCODE
-template <typename TI>
-void Embedding::forward_task_with_type(
-    Task const *task,
-    std::vector<PhysicalRegion> const &regions,
-    Context ctx,
-    Runtime *runtime) {
-  assert(regions.size() == 3);
-  assert(task->regions.size() == 3);
-  // const Embedding* embed = (Embedding*) task->args;
-  EmbeddingMeta const *m = *((EmbeddingMeta **)task->local_args);
-  Domain input_domain = runtime->get_index_space_domain(
-      ctx, task->regions[0].region.get_index_space());
-  Domain output_domain = runtime->get_index_space_domain(
-      ctx, task->regions[1].region.get_index_space());
-  Domain kernel_domain = runtime->get_index_space_domain(
-      ctx, task->regions[2].region.get_index_space());
-  if (m->aggr == AGGR_MODE_NONE)
{ - // assert(kernel_domain.get_dim() == 2); - assert(input_domain.get_dim() + 1 == output_domain.get_dim()); - for (size_t i = 0; i < input_domain.get_dim(); i++) { - assert(input_domain.hi()[i] == output_domain.hi()[i + 1]); - assert(input_domain.lo()[i] == output_domain.lo()[i + 1]); - } - assert(kernel_domain.hi()[0] - kernel_domain.lo()[0] == - output_domain.hi()[0] - output_domain.lo()[0]); - } else { - // assert(kernel_domain.get_dim() == 2); - assert(input_domain.get_dim() == output_domain.get_dim()); - for (size_t i = 1; i < input_domain.get_dim(); i++) { - assert(input_domain.hi()[i] == output_domain.hi()[i]); - assert(input_domain.lo()[i] == output_domain.lo()[i]); - } - assert(kernel_domain.hi()[0] - kernel_domain.lo()[0] == - output_domain.hi()[0] - output_domain.lo()[0]); - } - const TI *input_ptr = helperGetTensorPointerRO( - regions[0], task->regions[0], FID_DATA, ctx, runtime); - float *output_ptr = helperGetTensorPointerWO( - regions[1], task->regions[1], FID_DATA, ctx, runtime); - float const *kernel_ptr = helperGetTensorPointerRO( - regions[2], task->regions[2], FID_DATA, ctx, runtime); - - int in_dim, out_dim, effective_batch_size; - if (m->aggr == AGGR_MODE_NONE) { - in_dim = 1; - out_dim = output_domain.hi()[0] - output_domain.lo()[0] + 1; - effective_batch_size = output_domain.get_volume() / out_dim; - assert(effective_batch_size * in_dim == input_domain.get_volume()); - } else { - in_dim = input_domain.hi()[0] - input_domain.lo()[0] + 1; - out_dim = output_domain.hi()[0] - output_domain.lo()[0] + 1; - effective_batch_size = output_domain.get_volume() / out_dim; - assert(effective_batch_size * in_dim == input_domain.get_volume()); + if (m->inference_debugging) { + assert(task->index_point.get_dim() == 1); + int shard_id = task->index_point.point_data[0]; + Embedding::save_inference_tensors_to_file( + m, shard_id, nullptr, {input}, {kernel}, {output}); } - - forward_kernel_wrapper(m, - input_ptr, - output_ptr, - kernel_ptr, - in_dim, - out_dim, - effective_batch_size, - m->aggr, - output_domain.get_volume()); } -#endif void Embedding::backward(FFModel const &ff) { ArgumentMap argmap; diff --git a/src/ops/experts.cc b/src/ops/experts.cc index c8b0ec0f26..6a7d622e51 100644 --- a/src/ops/experts.cc +++ b/src/ops/experts.cc @@ -398,6 +398,7 @@ void Experts::serialize(Legion::Serializer &sez) const { ExpertsParams params = get_params(); sez.serialize(params.layer_guid.id); sez.serialize(params.layer_guid.transformer_layer_id); + sez.serialize(params.layer_guid.model_id); sez.serialize(params.num_experts); sez.serialize(params.experts_start_idx); sez.serialize(params.experts_output_dim_size); @@ -418,10 +419,11 @@ Node Experts::deserialize(FFModel &ff, float alpha; ActiMode activation; bool use_bias; - size_t id, transformer_layer_id; + size_t id, transformer_layer_id, deserialized_model_id; dez.deserialize(id); dez.deserialize(transformer_layer_id); - LayerID layer_guid(id, transformer_layer_id); + dez.deserialize(deserialized_model_id); + LayerID layer_guid(id, transformer_layer_id, deserialized_model_id); dez.deserialize(num_experts); dez.deserialize(experts_start_idx); dez.deserialize(experts_output_dim_size); @@ -593,6 +595,9 @@ OpMeta *Experts::init_task(Task const *task, exp->use_bias, exp->activation); m->profiling = exp->profiling; + m->inference_debugging = exp->inference_debugging; + std::strcpy(m->op_name, exp->name); + m->layer_guid = exp->layer_guid; return m; } @@ -732,7 +737,7 @@ void Experts::inference_task(Task const *task, Runtime *runtime) { 
assert(regions.size() == task->regions.size()); - ExpertsMeta const *m = *((ExpertsMeta **)task->local_args); + ExpertsMeta *m = *((ExpertsMeta **)task->local_args); BatchConfig const *bc = BatchConfig::from_future(task->futures[0]); if (bc->num_tokens == 0) { return; @@ -743,14 +748,19 @@ void Experts::inference_task(Task const *task, assert(regions.size() - 4 == (1 + use_bias)); // get input, indices, topk_gate_preds, outputs - float const *input_ptr = helperGetTensorPointerRO( - regions[0], task->regions[0], FID_DATA, ctx, runtime); - int const *indices_ptr = helperGetTensorPointerRO( - regions[1], task->regions[1], FID_DATA, ctx, runtime); - float const *topk_gate_pred_ptr = helperGetTensorPointerRO( - regions[2], task->regions[2], FID_DATA, ctx, runtime); - float *output_ptr = helperGetTensorPointerWO( - regions[3], task->regions[3], FID_DATA, ctx, runtime); + GenericTensorAccessorR input = helperGetGenericTensorAccessorRO( + DT_FLOAT, regions[0], task->regions[0], FID_DATA, ctx, runtime); + GenericTensorAccessorR indices = helperGetGenericTensorAccessorRO( + DT_INT32, regions[1], task->regions[1], FID_DATA, ctx, runtime); + GenericTensorAccessorR topk_gate_preds = helperGetGenericTensorAccessorRO( + DT_FLOAT, regions[2], task->regions[2], FID_DATA, ctx, runtime); + GenericTensorAccessorW output = helperGetGenericTensorAccessorWO( + DT_FLOAT, regions[3], task->regions[3], FID_DATA, ctx, runtime); + + float const *input_ptr = input.get_float_ptr(); + int const *indices_ptr = indices.get_int32_ptr(); + float const *topk_gate_pred_ptr = topk_gate_preds.get_float_ptr(); + float *output_ptr = output.get_float_ptr(); assert(input_ptr != nullptr && indices_ptr != nullptr && topk_gate_pred_ptr != nullptr && output_ptr != nullptr); @@ -1107,6 +1117,13 @@ void Experts::inference_task(Task const *task, free(cpu_output_ptr); } #endif + + if (m->inference_debugging) { + assert(task->index_point.get_dim() == 1); + int shard_id = task->index_point.point_data[0]; + Experts::save_inference_tensors_to_file( + m, shard_id, bc, {input, indices, topk_gate_preds}, {}, {output}); + } } void Experts::forward_task(Task const *task, diff --git a/src/ops/fused.cu b/src/ops/fused.cu index 5f2874e662..7d0d5841f0 100644 --- a/src/ops/fused.cu +++ b/src/ops/fused.cu @@ -1088,6 +1088,31 @@ __host__ void assert(false && "Fusion currently does not support type"); } } + if (metas->meta[op]->inference_debugging) { + std::vector input_accessors_to_save; + std::vector weight_accessors_to_save; + std::vector output_accessors_to_save; + for (int i = 0; i < fused->op_num_inputs[op]; i++) { + int my_off = fused->op_input_idx[i + ioff]; + if (fused->op_input_source[i + ioff] == SOURCE_INPUT) { + input_accessors_to_save.push_back(input_accessor[my_off]); + } else if (fused->op_input_source[i + ioff] == SOURCE_OUTPUT) { + input_accessors_to_save.push_back(output_accessor[my_off]); + } else { + assert(false); + } + } + for (int i = 0; i < fused->op_num_weights[op]; i++) { + assert(fused->op_weight_source[i + woff] == SOURCE_WEIGHT); + weight_accessors_to_save.push_back(weight_accessor[fused->op_weight_idx[i + woff]]); + } + for (int i = 0; i < fused->op_num_outputs[op]; i++) { + output_accessors_to_save.push_back(output_accessor[i + ooff]); + } + assert(task->index_point.get_dim() == 1); + int shard_id = task->index_point.point_data[0]; + FusedOp::save_inference_tensors_to_file(metas->meta[op], shard_id, bc, input_accessors_to_save, weight_accessors_to_save, output_accessors_to_save); + } ioff += fused->op_num_inputs[op]; woff += 
fused->op_num_weights[op]; ooff += fused->op_num_outputs[op]; diff --git a/src/ops/gather.cc b/src/ops/gather.cc index 635c741d8b..d7c1dee44c 100644 --- a/src/ops/gather.cc +++ b/src/ops/gather.cc @@ -167,6 +167,7 @@ void Gather::serialize(Legion::Serializer &sez) const { sez.serialize(params.legion_dim); sez.serialize(this->layer_guid.id); sez.serialize(this->layer_guid.transformer_layer_id); + sez.serialize(this->layer_guid.model_id); } using PCG::Node; @@ -178,10 +179,11 @@ Node Gather::deserialize(FFModel &ff, assert(num_inputs == 2); int legion_dim; dez.deserialize(legion_dim); - size_t id, transformer_layer_id; + size_t id, transformer_layer_id, deserialized_model_id; dez.deserialize(id); dez.deserialize(transformer_layer_id); - LayerID layer_guid(id, transformer_layer_id); + dez.deserialize(deserialized_model_id); + LayerID layer_guid(id, transformer_layer_id, deserialized_model_id); GatherParams params; params.legion_dim = legion_dim; @@ -243,6 +245,8 @@ OpMeta *Gather::init_task(Task const *task, Gather const *gather = (Gather const *)task->args; FFHandler handle = *((FFHandler const *)task->local_args); GatherMeta *m = new GatherMeta(handle, gather); + std::strcpy(m->op_name, gather->name); + m->layer_guid = gather->layer_guid; GenericTensorAccessorR input = helperGetGenericTensorAccessorRO( m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); GenericTensorAccessorR index = helperGetGenericTensorAccessorRO( diff --git a/src/ops/group_by.cc b/src/ops/group_by.cc index f2f94234c3..50871983f5 100644 --- a/src/ops/group_by.cc +++ b/src/ops/group_by.cc @@ -266,6 +266,9 @@ OpMeta *Group_by::init_task(Task const *task, FFHandler handle = *((FFHandler *)task->local_args); GroupByMeta *m = new GroupByMeta(handle, gb->n, gb->alpha); m->profiling = gb->profiling; + m->inference_debugging = gb->inference_debugging; + std::strcpy(m->op_name, gb->name); + m->layer_guid = gb->layer_guid; return m; } @@ -369,35 +372,39 @@ void Group_by::forward_task(Task const *task, int n = (int)regions.size() - 2; assert((int)task->regions.size() == n + 2); - GroupByMeta const *m = *((GroupByMeta **)task->local_args); + GroupByMeta *m = *((GroupByMeta **)task->local_args); // get input and assign regions. 
Each tensor has three dimensions: // (datapoint_dim, batch_size, replica_dim) - AccessorRO const acc_input(regions[0], FID_DATA); - AccessorRO const acc_assign(regions[1], FID_DATA); - - Rect<3> rect_input = runtime->get_index_space_domain( + GenericTensorAccessorR input = helperGetGenericTensorAccessorRO( + DT_FLOAT, regions[0], task->regions[0], FID_DATA, ctx, runtime); + GenericTensorAccessorR assign = helperGetGenericTensorAccessorRO( + DT_INT32, regions[1], task->regions[1], FID_DATA, ctx, runtime); + Domain input_domain = runtime->get_index_space_domain( ctx, task->regions[0].region.get_index_space()); - Rect<3> rect_assign = runtime->get_index_space_domain( + Domain assign_domain = runtime->get_index_space_domain( ctx, task->regions[1].region.get_index_space()); - coord_t input_rows = rect_input.hi[1] - rect_input.lo[1] + 1; - coord_t input_cols = rect_input.hi[0] - rect_input.lo[0] + 1; - assert(input_rows == rect_assign.hi[1] - rect_assign.lo[1] + 1); + coord_t input_rows = input_domain.hi()[1] - input_domain.lo()[1] + 1; + coord_t input_cols = input_domain.hi()[0] - input_domain.lo()[0] + 1; + assert(input_rows == assign_domain.hi()[1] - assign_domain.lo()[1] + 1); - int k = rect_assign.hi[0] - rect_assign.lo[0] + 1; + int k = assign_domain.hi()[0] - assign_domain.lo()[0] + 1; int batch_size = input_rows; int data_dim = input_cols; // Create a vector of n outputs, where n is the number of experts. // Each entry in the "outputs" vector points to the Legion tensor that will // contain the tockens dispatched to the corresponding expert + std::vector output_accessors; float *outputs[n]; for (int i = 0; i < n; i++) { + GenericTensorAccessorW output = helperGetGenericTensorAccessorWO( + DT_FLOAT, regions[i + 2], task->regions[i + 2], FID_DATA, ctx, runtime); + output_accessors.push_back(output); Domain out_domain = runtime->get_index_space_domain( ctx, task->regions[i + 2].region.get_index_space()); - outputs[i] = helperGetTensorPointerWO( - regions[i + 2], task->regions[i + 2], FID_DATA, ctx, runtime); + outputs[i] = output.get_float_ptr(); coord_t output_rows = out_domain.hi()[1] - out_domain.lo()[1] + 1; coord_t output_cols = out_domain.hi()[0] - out_domain.lo()[0] + 1; @@ -405,13 +412,19 @@ void Group_by::forward_task(Task const *task, } Group_by::forward_kernel_wrapper(m, - acc_input.ptr(rect_input), - acc_assign.ptr(rect_assign), + input.get_float_ptr(), + assign.get_int32_ptr(), outputs, n, k, batch_size, data_dim); + if (m->inference_debugging) { + assert(task->index_point.get_dim() == 1); + int shard_id = task->index_point.point_data[0]; + Group_by::save_inference_tensors_to_file( + m, shard_id, nullptr, {input, assign}, {}, output_accessors); + } } void Group_by::backward(FFModel const &ff) { diff --git a/src/ops/inc_multihead_self_attention.cc b/src/ops/inc_multihead_self_attention.cc index 68b5fa39a1..2f72976d30 100644 --- a/src/ops/inc_multihead_self_attention.cc +++ b/src/ops/inc_multihead_self_attention.cc @@ -23,10 +23,6 @@ #endif #include "flexflow/utils/hash_utils.h" #include "legion/legion_utilities.h" -#ifdef INFERENCE_TESTS -#include -using namespace at::indexing; -#endif namespace FlexFlow { @@ -725,6 +721,9 @@ OpMeta *IncMultiHeadSelfAttention::init_task( gpu_mem_allocator.reserved_total_size); } m->profiling = attn->profiling; + m->inference_debugging = attn->inference_debugging; + std::strcpy(m->op_name, attn->name); + m->layer_guid = attn->layer_guid; if (attn->quantization_type == DT_NONE) { assert(weight.domain.get_volume() * data_type_size(weight.data_type) 
== m->weightSize); @@ -811,7 +810,6 @@ void IncMultiHeadSelfAttention::inference_task( assert(task->regions.size() == regions.size()); - // BatchConfig const *bc = (BatchConfig *)task->args; BatchConfig const *bc = BatchConfig::from_future(task->futures[0]); log_inc_mha.debug("BatchConfig, num_tokens: %d, num_requests: %d", bc->num_tokens, @@ -820,7 +818,7 @@ void IncMultiHeadSelfAttention::inference_task( return; } - IncMultiHeadSelfAttentionMeta const *m = + IncMultiHeadSelfAttentionMeta *m = *((IncMultiHeadSelfAttentionMeta **)task->local_args); assert(((*m->qkv_bias || *m->final_bias) ? regions.size() == 4 @@ -860,787 +858,18 @@ void IncMultiHeadSelfAttention::inference_task( IncMultiHeadSelfAttention::inference_kernel_wrapper( m, bc, task->index_point.point_data[0], input, weight, output, biases); -#ifdef INFERENCE_TESTS - printf("Checking IncMultiHeadSelfAttention computations...\n"); - - // ============================================================================= - // Define helper functions to handle row-major arrays - // ============================================================================= - - auto set_value_row_major = [](float *arr, - std::vector const &shape, - std::vector const &indices, - float value) -> void { - int offset = 0; - for (int i = 0; i < shape.size(); i++) { - int index = indices[i]; - int stride = 1; - for (int j = i + 1; j < shape.size(); j++) { - stride *= shape[j]; - } - offset += index * stride; - } - *(arr + offset) = value; - }; - - // ============================================================================= - // Load input/output/weights and parse general configs - // ============================================================================= - - float *input_cpu = - download_tensor(input.get_float_ptr(), input_domain.get_volume()); - assert(input_cpu != nullptr); - float *weight_cpu = download_tensor(weight.get_float_ptr(), - weight_domain.get_volume()); - assert(weight_cpu != nullptr); - float *output_cpu = download_tensor(output.get_float_ptr(), - output_domain.get_volume()); - assert(output_cpu != nullptr); - - // Input tensor dimensions - coord_t data_dim = input_domain.hi()[0] - input_domain.lo()[0] + 1; - coord_t max_sequence_length = input_domain.hi()[1] - input_domain.lo()[1] + 1; - coord_t batch_size = input_domain.hi()[2] - input_domain.lo()[2] + 1; - coord_t replica_dim = input_domain.hi()[3] - input_domain.lo()[3] + 1; - assert(replica_dim == 1); - - size_t effective_batch_size = max_sequence_length * batch_size; - float inputs_arr[data_dim][effective_batch_size] = {0}; - for (size_t i = 0; i < data_dim * bc->num_active_tokens(); i++) { - size_t data_index = i % data_dim; - size_t token_index = i / data_dim; - assert(data_index < data_dim); - assert(token_index < effective_batch_size); - inputs_arr[data_index][token_index] = input_cpu[i]; - } - torch::Tensor torch_input = torch::from_blob( - inputs_arr, {data_dim, (long int)effective_batch_size}, torch::kFloat32); - - // Weight tensor dimensions - coord_t all_weight_params = weight_domain.hi()[0] - weight_domain.lo()[0] + 1; - coord_t num_q_heads = weight_domain.hi()[1] - weight_domain.lo()[1] + 1; - replica_dim = weight_domain.hi()[2] - weight_domain.lo()[2] + 1; - size_t qParas = m->qProjSize * m->qSize; - size_t kParas = m->kProjSize * m->kSize; - size_t vParas = m->vProjSize * m->vSize; - size_t oParas = m->oProjSize * (m->vProjSize > 0 ? 
m->vProjSize : m->vSize); - - assert(all_weight_params == qParas + kParas + vParas + oParas); - assert(num_q_heads == m->num_q_heads); - assert(replica_dim == 1); - - assert(m->qSize == m->kSize && m->kSize == m->vSize); - // printf("m->qSize: %i\n", m->qSize); - // keep things simple for now - assert(m->qProjSize == m->kProjSize && m->kProjSize == m->vProjSize); - long int proj_sum = m->qProjSize + m->kProjSize + m->vProjSize; - // load weight manually because Torch can't easily read a tensor serialized in - // column-major order. - - // printf("m->kProjSize: %i, BatchConfig::max_tokens_per_batch(): %i, " - // "bc->num_active_tokens(): %i, num_q_heads: %lli, - // BatchConfig::max_requests_per_batch(): %i, " - // "bc->num_active_requests(): %i\n", m->kProjSize, - // BatchConfig::max_tokens_per_batch(), bc->num_active_tokens(), - // num_q_heads, BatchConfig::max_requests_per_batch(), - // bc->num_active_requests()); - // for (int t=0; t < bc->num_active_tokens(); t++) { - // printf("token %i has request_index: %li and token_position: %li\n", - // t, bc->token2ids.token_indexes[t].request_index, - // bc->token2ids.token_indexes[t].token_position); - // } - - // ============================================================================= - // Load the output tensor (with CUDA results), and create a Torch tensor - // ============================================================================= - - float output_cuda[m->oProjSize][effective_batch_size] = {0}; - for (int i = 0; i < m->oProjSize * effective_batch_size; i++) { - int row_idx = i % m->oProjSize; - int col_idx = i / m->oProjSize; - assert(row_idx < m->oProjSize && col_idx < effective_batch_size); - output_cuda[row_idx][col_idx] = output_cpu[i]; - } - torch::Tensor torch_out_cuda = - torch::from_blob(output_cuda, - {m->oProjSize, (int64_t)effective_batch_size}, - torch::kFloat32); - - // ============================================================================= - // Load the Q/K/V projection weights, and create a Torch tensor - // ============================================================================= - std::vector w_qkv_shape = {m->qSize, m->qProjSize, 3, (int)num_q_heads}; - float *w_qkv = - (float *)calloc(m->qSize * m->qProjSize * 3 * num_q_heads, sizeof(float)); - assert(w_qkv[0] == 0.0f); - - for (int h = 0; h < num_q_heads; h++) { - for (size_t i = 0; i < m->qProjSize * m->qSize; i++) { - int row_index = i % m->qSize; - int column_index = i / m->qSize; - // Q - set_value_row_major(w_qkv, - w_qkv_shape, - {row_index, column_index, 0, h}, - weight_cpu[all_weight_params * h + - m->qSize * column_index + row_index]); - // K - set_value_row_major( - w_qkv, - w_qkv_shape, - {row_index, column_index, 1, h}, - weight_cpu[all_weight_params * h + m->qProjSize * m->qSize + - m->qSize * column_index + row_index]); - // V - set_value_row_major( - w_qkv, - w_qkv_shape, - {row_index, column_index, 2, h}, - weight_cpu[all_weight_params * h + 2 * m->qProjSize * m->qSize + - m->qSize * column_index + row_index]); - } - } - // convert weights to torch tensor - torch::Tensor torch_w_qkv = torch::from_blob( - w_qkv, {m->qSize, m->qProjSize, 3, (int)num_q_heads}, torch::kFloat32); - - /* std::cout << "Torch projection weights size: " << torch_w_qkv.sizes() - << std::endl; - std::cout << "Torch input size: " << torch_input.sizes() << std::endl; - std::cout << "Number of active tokens: " << bc->num_active_tokens() - << std::endl; */ - // std::cout << "torch_w_qkv:" << std::endl << torch_w_qkv << std::endl; - - // 
============================================================================= - // Compute the Q/K/V projections, and compare the results with CUDA - // ============================================================================= - // ----------------------- C++ computations & checks ------------------------ - torch::Tensor qkv_projs = torch::einsum( - "ijkl,im->jmkl", - {torch_w_qkv, - torch_input.index({Slice(), Slice(0, bc->num_active_tokens())})}); - // std::cout << "qkv_projs size: " << qkv_projs.sizes() << std::endl; - assert(qkv_projs.sizes()[0] == m->qProjSize); - assert(qkv_projs.sizes()[1] == bc->num_active_tokens() && - qkv_projs.sizes()[1] <= effective_batch_size); - assert(qkv_projs.sizes()[2] == 3); - assert(qkv_projs.sizes()[3] == num_q_heads); - free(w_qkv); - - // ----------------------- Loading CUDA results for this step --------------- - float *QKVProjArray_cpu = download_tensor( - m->devQKVProjArray, - BatchConfig::max_tokens_per_batch() * proj_sum * m->num_q_heads); - assert(QKVProjArray_cpu != nullptr); - - std::vector QKVProjArray_converted_shape = { - m->qProjSize, bc->num_active_tokens(), 3, (int)num_q_heads}; - float *QKVProjArray_converted = (float *)calloc( - m->qProjSize * bc->num_active_tokens() * 3 * num_q_heads, sizeof(float)); - - // skip over padding at the end of QKVProjArray_cpu - // convert from column order to 3D matrix because torch cannot automatically - // import matrices flattened in column order - for (size_t i = 0; i < proj_sum * bc->num_active_tokens() * num_q_heads; - i++) { - int proj_size_index = i % m->qProjSize; - int head_index = i / (proj_sum * bc->num_active_tokens()); - int token_index = - ((i - head_index * proj_sum * bc->num_active_tokens()) / m->qProjSize) % - bc->num_active_tokens(); - int qkv_offset = (i - head_index * proj_sum * bc->num_active_tokens()) / - (m->qProjSize * bc->num_active_tokens()); - assert(proj_size_index < proj_sum); - assert(head_index < num_q_heads); - assert(token_index < bc->num_active_tokens()); - assert(qkv_offset < 3); - set_value_row_major(QKVProjArray_converted, - QKVProjArray_converted_shape, - {proj_size_index, token_index, qkv_offset, head_index}, - QKVProjArray_cpu[i]); - } - torch::Tensor QKVProjArray_torch = - torch::from_blob(QKVProjArray_converted, - {m->qProjSize, bc->num_active_tokens(), 3, num_q_heads}, - torch::kFloat32); - - // ----------------------- Comparing C++ & CUDA results --------------------- - // std::cout << "QKVProjArray_torch" << std::endl; - // for (int i=0; inum_active_tokens(); t++) { - for (size_t d = 0; d < m->kProjSize; d++) { - size_t kcache_idx = d * MAX_SEQ_LEN * m->num_q_heads * - BatchConfig::max_requests_per_batch() + - bc->tokensInfo[t].abs_depth_in_request * - m->num_q_heads * - BatchConfig::max_requests_per_batch() + - h * BatchConfig::max_requests_per_batch() + - bc->tokensInfo[t].request_index; - m->kcache[kcache_idx] = - qkv_projs.index({(int64_t)d, (int64_t)t, 1, (int64_t)h}) - .item(); - } - for (size_t d = 0; d < m->vProjSize; d++) { - size_t vcache_idx = d * MAX_SEQ_LEN * m->num_q_heads * - BatchConfig::max_requests_per_batch() + - bc->tokensInfo[t].abs_depth_in_request * - m->num_q_heads * - BatchConfig::max_requests_per_batch() + - h * BatchConfig::max_requests_per_batch() + - bc->tokensInfo[t].request_index; - m->vcache[vcache_idx] = - qkv_projs.index({(int64_t)d, (int64_t)t, 2, (int64_t)h}) - .item(); - } - } - } - // Create torch tensors from the arrays - torch::Tensor K_t = torch::from_blob(m->kcache, - {m->kProjSize, - MAX_SEQ_LEN, - num_q_heads, - 
BatchConfig::max_requests_per_batch()}, - torch::kFloat32); - torch::Tensor V_t = torch::from_blob(m->vcache, - {m->vProjSize, - MAX_SEQ_LEN, - num_q_heads, - BatchConfig::max_requests_per_batch()}, - torch::kFloat32); - - // Compute useful indices - std::vector req_idxs; - std::vector r_first_idx; - std::vector r_num_tokens; - for (size_t t = 0; t < bc->num_active_tokens(); t++) { - size_t rid = bc->tokensInfo[t].request_index; - if (req_idxs.size() == 0 || req_idxs[req_idxs.size() - 1] != rid) { - req_idxs.push_back(rid); - r_first_idx.push_back(t); - r_num_tokens.push_back(1); - } else { - r_num_tokens[r_num_tokens.size() - 1]++; + if (m->inference_debugging) { + assert(task->index_point.get_dim() == 1); + int shard_id = task->index_point.point_data[0]; + std::vector weights_accessors; + weights_accessors.push_back(weight); + if (*m->qkv_bias || *m->final_bias) { + weights_accessors.push_back(biases); } - assert(req_idxs.size() == r_first_idx.size() && - r_first_idx.size() == r_num_tokens.size()); + IncMultiHeadSelfAttention::save_inference_tensors_to_file( + m, shard_id, bc, {input}, weights_accessors, {output}); } - assert(req_idxs.size() == bc->num_active_requests()); - assert(std::accumulate(r_num_tokens.begin(), - r_num_tokens.end(), - decltype(r_num_tokens)::value_type(0)) == - bc->num_active_tokens()); - - // ----------------------- Loading CUDA results for this step --------------- - float *keyCache_cpu = download_tensor( - m->keyCache, - m->num_q_heads * m->kProjSize * BatchConfig::max_requests_per_batch() * - MAX_SEQ_LEN); - float *valueCache_cpu = download_tensor( - m->valueCache, - m->num_q_heads * m->vProjSize * BatchConfig::max_requests_per_batch() * - MAX_SEQ_LEN); - assert(keyCache_cpu != nullptr); - assert(valueCache_cpu != nullptr); - - float *kcache_cuda = - (float *)calloc(m->kProjSize * MAX_SEQ_LEN * m->num_q_heads * - BatchConfig::max_requests_per_batch(), - sizeof(float)); - float *vcache_cuda = - (float *)calloc(m->vProjSize * MAX_SEQ_LEN * m->num_q_heads * - BatchConfig::max_requests_per_batch(), - sizeof(float)); - int index = 0; - for (int i = 0; i < m->kProjSize; i++) { - for (int j = 0; j < MAX_SEQ_LEN; j++) { - for (int k = 0; k < m->num_q_heads; k++) { - for (int l = 0; l < BatchConfig::max_requests_per_batch(); l++) { - int col_major_index = - l * m->kProjSize * MAX_SEQ_LEN * m->num_q_heads + - k * m->kProjSize * MAX_SEQ_LEN + j * m->kProjSize + i; - kcache_cuda[index++] = keyCache_cpu[col_major_index]; - } - } - } - } - index = 0; - for (int i = 0; i < m->vProjSize; i++) { - for (int j = 0; j < MAX_SEQ_LEN; j++) { - for (int k = 0; k < m->num_q_heads; k++) { - for (int l = 0; l < BatchConfig::max_requests_per_batch(); l++) { - int col_major_index = - l * m->vProjSize * MAX_SEQ_LEN * m->num_q_heads + - k * m->vProjSize * MAX_SEQ_LEN + j * m->vProjSize + i; - vcache_cuda[index++] = valueCache_cpu[col_major_index]; - } - } - } - } - torch::Tensor K_t_cuda = - torch::from_blob(kcache_cuda, - {m->kProjSize, - MAX_SEQ_LEN, - num_q_heads, - BatchConfig::max_requests_per_batch()}, - torch::kFloat32); - torch::Tensor V_t_cuda = - torch::from_blob(vcache_cuda, - {m->vProjSize, - MAX_SEQ_LEN, - num_q_heads, - BatchConfig::max_requests_per_batch()}, - torch::kFloat32); - - // ----------------------- Comparing C++ & CUDA results --------------------- - - // std::cout << "kcache differences:" << std::endl; - // for (int i=0; i < bc->num_active_requests() + 1; i++) { - // for (int j=0; j < num_q_heads; j++) { - // for (int l=0; l < m->kProjSize; l++) { - // for (int 
k=0; k < MAX_SEQ_LEN; k++) { - // size_t kcache_idx = - // l * MAX_SEQ_LEN * num_q_heads * - // BatchConfig::max_requests_per_batch() + k * num_q_heads * - // BatchConfig::max_requests_per_batch() + j * - // BatchConfig::max_requests_per_batch() + i; if ( - // abs(m->kcache[kcache_idx] - keyCache_cpu[ - // i * m->kProjSize * MAX_SEQ_LEN * num_q_heads + - // j * m->kProjSize * MAX_SEQ_LEN + - // k * m->kProjSize + - // l - // ]) > 0.00001) { - // printf("req: %i (rid: %i), head: %i, data_dim: %i, token_pos: - // %i\n", - // i, req_idxs[i], j, l, k); - // } - // } - // } - // } - // } - - // std::cout << "keyCache from CUDA:" << std::endl; - // for (int i=0; inum_active_requests()+1; i++) { - // for (int j=0; jkProjSize; l++) { - // for (int k=0; k< MAX_SEQ_LEN; k++) { - // printf("%f ", - // keyCache_cpu[i * m->kProjSize * MAX_SEQ_LEN * num_q_heads + - // j * m->kProjSize * MAX_SEQ_LEN + - // k * m->kProjSize + - // l - // ]); - // } - // printf("\n"); - // } - // printf("\n"); - // } - // printf("\n"); - // } - - // std::cout << "valueCache from CUDA:" << std::endl; - // for (int i=0; inum_active_requests()+1; i++) { - // for (int j=0; jvProjSize; l++) { - // for (int k=0; k< MAX_SEQ_LEN; k++) { - // printf("%f ", - // valueCache_cpu[ - // i * m->vProjSize * MAX_SEQ_LEN * num_q_heads + - // j * m->vProjSize * MAX_SEQ_LEN + - // k * m->vProjSize + - // l]); - // } - // printf("\n"); - // } - // printf("\n"); - // } - // printf("\n"); - // } - - // printf("\n"); - - // std::cout << "C++ kcache:" << std::endl; - // for (int i=0; inum_active_requests()+1; i++) { - // for (int j=0; j < num_q_heads; j++) { - // for (int l=0; l < m->kProjSize; l++) { - // for (int k=0; k < MAX_SEQ_LEN; k++) { - // size_t kcache_idx = - // l * MAX_SEQ_LEN * num_q_heads * - // BatchConfig::max_requests_per_batch() + k * num_q_heads * - // BatchConfig::max_requests_per_batch() + j * - // BatchConfig::max_requests_per_batch() + i; - // printf("%f ", m->kcache[kcache_idx]); - // } - // printf("\n"); - // } - // printf("\n"); - // } - // printf("\n"); - // } - - // std::cout << "C++ vcache:" << std::endl; - // for (int i=0; inum_active_requests()+1; i++) { - // for (int j=0; jvProjSize; l++) { - // for (int k=0; k< MAX_SEQ_LEN; k++) { - // size_t vcache_idx = - // l * MAX_SEQ_LEN * num_q_heads * - // BatchConfig::max_requests_per_batch() - // + k * num_q_heads * BatchConfig::max_requests_per_batch() + j - // * BatchConfig::max_requests_per_batch() + i; - // printf("%f ", m->vcache[vcache_idx]); - // } - // printf("\n"); - // } - // printf("\n"); - // } - // printf("\n"); - // } - - assert(torch::allclose(K_t_cuda, K_t, 1e-05, 1e-05)); - assert(torch::allclose(V_t_cuda, V_t, 1e-05, 1e-05)); - free(kcache_cuda); - free(vcache_cuda); - - // ============================================================================= - // Load the W_out projection weights - // ============================================================================= - - // ----------------------- C++ operations & checks -------------------------- - float *w_out = (float *)calloc(m->vProjSize * m->num_q_heads * m->oProjSize, - sizeof(float)); - std::vector w_out_shape = {m->vProjSize, m->num_q_heads, m->oProjSize}; - assert(m->qProjSize == m->kProjSize && m->kProjSize == m->vProjSize); - for (int h = 0; h < num_q_heads; h++) { - for (int v = 0; v < m->vProjSize; v++) { - for (int o = 0; o < m->oProjSize; o++) { - set_value_row_major( - w_out, - w_out_shape, - {v, h, o}, - weight_cpu[all_weight_params * h + 3 * m->qProjSize * m->qSize + - m->vProjSize * 
o + v]); - } - } - } - // convert weights to torch tensor - torch::Tensor torch_w_out = torch::from_blob( - w_out, {m->vProjSize, m->num_q_heads, m->oProjSize}, torch::kFloat32); - - // ----------------------- Loading CUDA results for this step --------------- - float *w_out_cuda = download_tensor( - m->W_out_contiguous, m->vProjSize * m->oProjSize * m->num_q_heads); - assert(w_out_cuda != nullptr); - float *converted_wout_tensor = (float *)calloc( - m->vProjSize * m->num_q_heads * m->oProjSize, sizeof(float)); - std::vector converted_wout_tensor_shape = { - m->vProjSize, m->num_q_heads, m->oProjSize}; - - for (int i = 0; i < m->vProjSize * m->num_q_heads * m->oProjSize; i++) { - int v_idx = i % m->vProjSize; - int h_idx = (i / m->vProjSize) % m->num_q_heads; - int o_idx = i / (m->vProjSize * m->num_q_heads); - assert(v_idx < m->vProjSize && h_idx < m->num_q_heads && - o_idx < m->oProjSize); - set_value_row_major(converted_wout_tensor, - converted_wout_tensor_shape, - {v_idx, h_idx, o_idx}, - w_out_cuda[i]); - } - torch::Tensor w_out_cuda_tensor = - torch::from_blob(converted_wout_tensor, - {m->vProjSize, m->num_q_heads, m->oProjSize}, - torch::kFloat32); - - // ----------------------- Comparing C++ & CUDA results --------------------- - assert(torch::allclose(w_out_cuda_tensor, torch_w_out, 1e-05, 1e-05)); - free(converted_wout_tensor); - - // ============================================================================= - // Compute the softmax(QK^T/sqrt(d_k))V product, request by request - // ============================================================================= - - // ----------------------- C++ initialization steps ------------------------- - torch::Tensor Q_projs = qkv_projs.index({Slice(), Slice(), 0, Slice()}) - .reshape({qkv_projs.sizes()[0], - qkv_projs.sizes()[1], - qkv_projs.sizes()[3]}); - - torch::Tensor qk_products[bc->num_active_requests()]; - torch::Tensor qk_softmax[bc->num_active_requests()]; - torch::Tensor attn_heads[bc->num_active_requests()]; - - torch::Tensor cpp_output = - torch::zeros({m->oProjSize, bc->num_active_tokens()}); - - // ----------------------- Loading CUDA results for this step --------------- - float *qk_prods_cpu = download_tensor( - m->qk_prods, - BatchConfig::max_tokens_per_batch() * - BatchConfig::max_tokens_per_batch() * num_q_heads); - assert(qk_prods_cpu != nullptr); - - float *qk_prods_softmax_cpu = download_tensor( - m->qk_prods_softmax, - BatchConfig::max_tokens_per_batch() * - BatchConfig::max_tokens_per_batch() * num_q_heads); - assert(qk_prods_softmax_cpu != nullptr); - - float *attn_heads_cpu = download_tensor( - m->attn_heads, - BatchConfig::max_tokens_per_batch() * m->num_q_heads * m->vProjSize); - assert(attn_heads_cpu != nullptr); - - // ----------------------- Main loop (request by request) ------------------- - size_t qk_prods_cpu_offset = 0; - - for (size_t r = 0; r < bc->num_active_requests(); r++) { - // Compute pre-request parameters - size_t num_new_tokens = r_num_tokens[r]; - int64_t rid = (int64_t)(req_idxs[r]); - int64_t num_tokens_received_so_far = - (int64_t)(bc->requestsInfo[rid].token_start_offset + - bc->requestsInfo[rid].num_tokens_in_batch); - assert(num_new_tokens == bc->requestsInfo[rid].num_tokens_in_batch); - assert(num_tokens_received_so_far >= (int64_t)num_new_tokens); - - // ----------------------- C++ computations ------------------------------- - // Get the slice of the Q projection tensor with the tokens in the current - // request - torch::Tensor Q_req = - Q_projs.index({Slice(), - 
Slice(r_first_idx[r], r_first_idx[r] + num_new_tokens), - Slice()}); - // std::cout << "Q_req.sizes(): " << Q_req.sizes() << std::endl; - assert(Q_req.sizes()[0] == m->qProjSize); - assert(Q_req.sizes()[1] == num_new_tokens); - assert(Q_req.sizes()[2] == num_q_heads); - - /*printf("\n------------ QK multiplication (C++) -------------\n"); - printf("Request r=%lu. num_new_tokens: %lu, num_tokens_received_so_far: %li, - rid: %li, Qproj slice: (%i, %i)\n", r, num_new_tokens, - num_tokens_received_so_far, rid, r_first_idx[r], r_first_idx[r] + - num_new_tokens); - - std::cout << "Q_req matrix (idk dims):" << std::endl << - Q_req.index({Slice(), Slice(), 0}) << std::endl << std::endl; std::cout << - "K_t matrix (ilk dims):" << std::endl << K_t.index({Slice(), Slice(0, - num_tokens_received_so_far), 0, rid}) << std::endl << std::endl; std::cout - << "C++ alpha: " << (1.0f / sqrt(m->kProjSize)) << std::endl;*/ - - // Compute (Q*K^T)/sqrt(d_k) matmul - qk_products[r] = - torch::einsum("ijk,ilk->jlk", - {Q_req, - K_t.index({Slice(), - Slice(0, num_tokens_received_so_far), - Slice(), - rid})}) * - (1.0f / sqrt(m->kProjSize)); - - // Set entries above diagonal to -inf to make attention causal. - for (int h = 0; h < num_q_heads; h++) { - qk_products[r].index( - {Slice(), Slice(num_tokens_received_so_far - num_new_tokens), h}) = - qk_products[r] - .index({Slice(), - Slice(num_tokens_received_so_far - num_new_tokens), - h}) - .tril() + - torch::full({(int64_t)num_new_tokens, (int64_t)num_new_tokens}, - -INFINITY) - .triu() - .fill_diagonal_(0); - } - // Compute softmax for each request block - qk_softmax[r] = torch::softmax(qk_products[r], -2); - assert(qk_softmax[r].sizes()[0] == num_new_tokens); - assert(qk_softmax[r].sizes()[1] == num_tokens_received_so_far); - assert(qk_softmax[r].sizes()[2] == m->num_q_heads); - - // ------------------- Loading CUDA results for this step --------------- - float *converted_qk_prod = (float *)calloc( - num_new_tokens * num_tokens_received_so_far * num_q_heads, - sizeof(float)); - float *converted_qk_prod_softmax = (float *)calloc( - num_new_tokens * num_tokens_received_so_far * num_q_heads, - sizeof(float)); - std::vector converted_qk_prod_shape = { - (int)num_new_tokens, (int)num_tokens_received_so_far, (int)num_q_heads}; - - for (size_t i = 0; - i < num_new_tokens * num_tokens_received_so_far * num_q_heads; - i++) { - size_t new_t_idx = i % num_new_tokens; - size_t all_t_idx = (i / num_new_tokens) % num_tokens_received_so_far; - size_t head_idx = i / (num_new_tokens * num_tokens_received_so_far); - assert(new_t_idx < num_new_tokens && - all_t_idx < num_tokens_received_so_far && head_idx < num_q_heads); - set_value_row_major(converted_qk_prod, - converted_qk_prod_shape, - {(int)new_t_idx, (int)all_t_idx, (int)head_idx}, - qk_prods_cpu[i + qk_prods_cpu_offset]); - set_value_row_major(converted_qk_prod_softmax, - converted_qk_prod_shape, - {(int)new_t_idx, (int)all_t_idx, (int)head_idx}, - qk_prods_softmax_cpu[i + qk_prods_cpu_offset]); - } - torch::Tensor qk_prods_cuda = torch::from_blob( - converted_qk_prod, - {(int64_t)num_new_tokens, num_tokens_received_so_far, num_q_heads}, - torch::kFloat32); - torch::Tensor qk_prods_softmax_cuda = torch::from_blob( - converted_qk_prod_softmax, - {(int64_t)num_new_tokens, num_tokens_received_so_far, num_q_heads}, - torch::kFloat32); - - // ------------------- Comparing C++ & CUDA results ------------------ - /* std::cout << "C++:" <vProjSize); - assert( - V_t.index({Slice(), Slice(0, num_tokens_received_so_far), Slice(), rid}) - 
.sizes()[1] == num_tokens_received_so_far); - assert( - V_t.index({Slice(), Slice(0, num_tokens_received_so_far), Slice(), rid}) - .sizes()[2] == m->num_q_heads); - attn_heads[r] = torch::einsum( - "ijk,ljk->ilk", - {qk_softmax[r], - V_t.index( - {Slice(), Slice(0, num_tokens_received_so_far), Slice(), rid})}); - assert(attn_heads[r].sizes()[0] == num_new_tokens); - assert(attn_heads[r].sizes()[1] == m->vProjSize); - assert(attn_heads[r].sizes()[2] == m->num_q_heads); - - // ------------------- Loading CUDA results for this step --------------- - float converted_attn_heads_cpu[num_new_tokens][m->vProjSize] - [m->num_q_heads] = {0}; - for (int i = 0; i < num_new_tokens * m->vProjSize * m->num_q_heads; i++) { - int token_ix = i % num_new_tokens; - int vproj_idx = (i / num_new_tokens) % m->vProjSize; - int head_idx = i / (num_new_tokens * m->vProjSize); - assert(token_ix < num_new_tokens && vproj_idx < m->vProjSize && - head_idx < m->num_q_heads); - converted_attn_heads_cpu[token_ix][vproj_idx][head_idx] = - attn_heads_cpu[r_first_idx[r] * m->vProjSize * m->num_q_heads + i]; - } - torch::Tensor converted_attn_heads_cuda = torch::from_blob( - converted_attn_heads_cpu, - {(int64_t)num_new_tokens, m->vProjSize, m->num_q_heads}, - torch::kFloat32); - - // -------------------- Comparing C++ & CUDA results ------------------- - /* std::cout << "CUDA attn head for req " << r << ":" <num_q_heads; h++) { - std::cout << converted_attn_heads_cuda.index({Slice(), Slice(), h}) << - std::endl; - } - std::cout << "C++ attn head for req " << r << ":" <num_q_heads; h++) { - std::cout << attn_heads[r].index({Slice(), Slice(), h}) << std::endl; - } */ - assert(torch::allclose( - converted_attn_heads_cuda, attn_heads[r], 1e-05, 1e-05)); - - // ----------------------- C++ computations ---------------------------- - // Compute output values by projecting all heads to output space - cpp_output.index( - {Slice(), - Slice(r_first_idx[r], r_first_idx[r] + (int64_t)num_new_tokens)}) = - torch::einsum("jkl,ijk->li", {torch_w_out, attn_heads[r]}); - - // increment main loop's auxiliary index - qk_prods_cpu_offset += - num_new_tokens * num_tokens_received_so_far * num_q_heads; - } - - // ----------------------- Comparing C++ & CUDA results --------------------- - /* std::cout << "C++:" <oProjSize; i++) { - std::cout << cpp_output.index({i, Slice()}) << std::endl; - } - std::cout << "CUDA:" <oProjSize; i++) { - std::cout << torch_out_cuda.index({i, Slice(0, - (int64_t)bc->num_active_tokens())}) << std::endl; - } */ - - assert(torch::allclose( - torch_out_cuda.index( - {Slice(), Slice(0, (int64_t)bc->num_active_tokens())}), - cpp_output, - 1e-05, - 1e-05)); - - // ============================================================================= - // Cleanup - // ============================================================================= - free(w_out); - checkCUDA(cudaFreeHost(input_cpu)); - checkCUDA(cudaFreeHost(weight_cpu)); - checkCUDA(cudaFreeHost(output_cpu)); - checkCUDA(cudaFreeHost(QKVProjArray_cpu)); - checkCUDA(cudaFreeHost(keyCache_cpu)); - checkCUDA(cudaFreeHost(valueCache_cpu)); - checkCUDA(cudaFreeHost(qk_prods_cpu)); - checkCUDA(cudaFreeHost(qk_prods_softmax_cpu)); - checkCUDA(cudaFreeHost(attn_heads_cpu)); - checkCUDA(cudaFreeHost(w_out_cuda)); - // assert(false && "All good if you see this assert failure! 
:)"); -#endif - // Done with INFERENCE_TESTS block } void IncMultiHeadSelfAttention::backward(FFModel const &ff) { diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index eaaa398654..5a2a14387e 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -1097,17 +1097,6 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( bias_ptr = gpu_mem_allocator.allocate_reserved_untyped(biasSize); } -#ifdef INFERENCE_TESTS - kcache = - (float *)calloc(kProjSize * BatchConfig::max_sequence_length() * - num_q_heads * BatchConfig::max_requests_per_batch(), - sizeof(float)); - vcache = - (float *)calloc(vProjSize * BatchConfig::max_sequence_length() * - num_q_heads * BatchConfig::max_requests_per_batch(), - sizeof(float)); -#endif - // allocate memory for the seqArray and reserve space { int max_tokens_per_batch = BatchConfig::max_tokens_per_batch(); @@ -1253,10 +1242,6 @@ IncMultiHeadSelfAttentionMeta::~IncMultiHeadSelfAttentionMeta(void) { if (reserveInst != Realm::RegionInstance::NO_INST) { reserveInst.destroy(); } -#ifdef INFERENCE_TESTS - free(kcache); - free(vcache); -#endif } template void Kernels::IncMultiHeadAttention::pre_build_weight_kernel( diff --git a/src/ops/kernels/dropout_kernels.cpp b/src/ops/kernels/dropout_kernels.cpp index b0dd4c644e..14225f0bce 100644 --- a/src/ops/kernels/dropout_kernels.cpp +++ b/src/ops/kernels/dropout_kernels.cpp @@ -30,6 +30,7 @@ DropoutMeta::DropoutMeta(FFHandler handler, Domain const &output_domain) : OpMeta(handler) { profiling = dropout->profiling; + inference_debugging = dropout->inference_debugging; checkCUDNN(miopenCreateTensorDescriptor(&inputTensor)); checkCUDNN(miopenCreateTensorDescriptor(&outputTensor)); checkCUDNN(miopenCreateDropoutDescriptor(&dropoutDesc)); diff --git a/src/ops/kernels/dropout_kernels.cu b/src/ops/kernels/dropout_kernels.cu index 4a76301fd6..e142bba83b 100644 --- a/src/ops/kernels/dropout_kernels.cu +++ b/src/ops/kernels/dropout_kernels.cu @@ -29,6 +29,7 @@ DropoutMeta::DropoutMeta(FFHandler handler, Domain const &output_domain) : OpMeta(handler) { profiling = dropout->profiling; + inference_debugging = dropout->inference_debugging; checkCUDNN(cudnnCreateTensorDescriptor(&inputTensor)); checkCUDNN(cudnnCreateTensorDescriptor(&outputTensor)); checkCUDNN(cudnnCreateDropoutDescriptor(&dropoutDesc)); diff --git a/src/ops/kernels/element_binary_kernels.cu b/src/ops/kernels/element_binary_kernels.cu index ff5d5a67e1..42b31a664a 100644 --- a/src/ops/kernels/element_binary_kernels.cu +++ b/src/ops/kernels/element_binary_kernels.cu @@ -30,6 +30,7 @@ ElementBinaryMeta::ElementBinaryMeta(FFHandler handler, Op const *op) checkCUDNN(cudnnCreateReduceTensorDescriptor(&reduceAddDesc)); op_type = OP_NOOP; profiling = false; + inference_debugging = false; inplace_a = false; has_same_operands = false; broadcast_input1 = false; diff --git a/src/ops/kernels/softmax.cpp b/src/ops/kernels/softmax.cpp index ca4872d51b..89c9f14a01 100644 --- a/src/ops/kernels/softmax.cpp +++ b/src/ops/kernels/softmax.cpp @@ -34,6 +34,7 @@ SoftmaxMeta::SoftmaxMeta(FFHandler handler, cudnnSetTensorDescriptorFromDomain4SoftMax(outputTensor, input_domain)); dim = softmax->dim; profiling = softmax->profiling; + inference_debugging = softmax->inference_debugging; std::strcpy(op_name, softmax->name); } diff --git a/src/ops/kernels/softmax.cu b/src/ops/kernels/softmax.cu index 67a9c21038..e47006cc9d 100644 --- a/src/ops/kernels/softmax.cu +++ b/src/ops/kernels/softmax.cu @@ -33,6 
+33,7 @@ SoftmaxMeta::SoftmaxMeta(FFHandler handler, outputTensor, input_domain, softmax->data_type)); dim = softmax->dim; profiling = softmax->profiling; + inference_debugging = softmax->inference_debugging; std::strcpy(op_name, softmax->name); } diff --git a/src/ops/layer_norm.cc b/src/ops/layer_norm.cc index cb519239c5..bc1358e49c 100644 --- a/src/ops/layer_norm.cc +++ b/src/ops/layer_norm.cc @@ -383,6 +383,8 @@ OpMeta *LayerNorm::init_task(Task const *task, .first(); MemoryAllocator gpu_mem_allocator(gpu_mem); LayerNormMeta *meta = new LayerNormMeta(handle, ln, gpu_mem_allocator); + std::strcpy(meta->op_name, ln->name); + meta->layer_guid = ln->layer_guid; meta->input_type[0] = ln->inputs[0]->data_type; meta->output_type[0] = ln->outputs[0]->data_type; return meta; } @@ -504,7 +506,7 @@ void LayerNorm::inference_task(Task const *task, return; } - LayerNormMeta const *m = *((LayerNormMeta **)task->local_args); + LayerNormMeta *m = *((LayerNormMeta **)task->local_args); assert(task->regions.size() == regions.size()); float const *in_ptr = NULL; float *out_ptr = NULL, *gamma_ptr = NULL, *beta_ptr = NULL; @@ -558,7 +560,22 @@ void LayerNorm::inference_task(Task const *task, } else { assert(regions.size() == 2); } + LayerNorm::forward_kernel_wrapper(m, in, out, gamma, beta); + + if (m->inference_debugging) { + assert(task->index_point.get_dim() == 1); + int shard_id = task->index_point.point_data[0]; + std::vector<GenericTensorAccessorR> weights_accessors; + if (m->elementwise_affine) { + weights_accessors.push_back(gamma); + if (m->use_bias) { + weights_accessors.push_back(beta); + } + } + LayerNorm::save_inference_tensors_to_file( + m, shard_id, bc, {in}, weights_accessors, {out}); + } } /* @@ -858,6 +875,7 @@ bool LayerNorm::measure_operator_cost(Simulator *sim, void LayerNorm::serialize(Legion::Serializer &sez) const { sez.serialize(this->layer_guid.id); sez.serialize(this->layer_guid.transformer_layer_id); + sez.serialize(this->layer_guid.model_id); sez.serialize(this->axes.size()); for (size_t i = 0; i < this->axes.size(); i++) { sez.serialize(this->axes[i]); @@ -879,10 +897,11 @@ Node LayerNorm::deserialize(FFModel &ff, bool elementwise_affine; bool use_bias; float eps; - size_t id, transformer_layer_id; + size_t id, transformer_layer_id, deserialized_model_id; dez.deserialize(id); dez.deserialize(transformer_layer_id); - LayerID layer_guid(id, transformer_layer_id); + dez.deserialize(deserialized_model_id); + LayerID layer_guid(id, transformer_layer_id, deserialized_model_id); dez.deserialize(num_axes); for (size_t i = 0; i < num_axes; i++) { int axis_idx; diff --git a/src/ops/layer_norm.cu b/src/ops/layer_norm.cu index 12751113a2..6d29071e38 100644 --- a/src/ops/layer_norm.cu +++ b/src/ops/layer_norm.cu @@ -33,6 +33,7 @@ LayerNormMeta::LayerNormMeta(FFHandler handle, effective_batch_size = ln->effective_batch_size; effective_num_elements = ln->effective_num_elements; profiling = ln->profiling; + inference_debugging = ln->inference_debugging; eps = ln->eps; DataType data_type = ln->data_type; size_t totalSize = effective_batch_size * data_type_size(data_type) * 6; diff --git a/src/ops/linear.cc b/src/ops/linear.cc index a751ebcc57..63b26bfe7d 100644 --- a/src/ops/linear.cc +++ b/src/ops/linear.cc @@ -504,11 +504,13 @@ OpMeta *Linear::init_task_with_dim(Task const *task, m->use_bias = linear->use_bias; m->add_bias_only_once = linear->add_bias_only_once; m->profiling = linear->profiling; + m->inference_debugging = linear->inference_debugging; m->trainableInputs[0] = linear->trainableInputs[0];
m->weight_ptr_type = m->input_type[0]; m->quantization_type = linear->quantization_type; m->offload = linear->offload; std::strcpy(m->op_name, linear->name); + m->layer_guid = linear->layer_guid; init_kernel(m, batch_size, out_dim); @@ -617,7 +619,7 @@ void Linear::inference_task(Task const *task, Runtime *runtime) { Domain input_domain = runtime->get_index_space_domain( ctx, task->regions[0].region.get_index_space()); - LinearMeta const *m = *((LinearMeta **)task->local_args); + LinearMeta *m = *((LinearMeta **)task->local_args); BatchConfig const *bc = BatchConfig::from_future(task->futures[0]); if (bc->num_tokens == 0) { return; } @@ -658,6 +660,18 @@ void Linear::inference_task(Task const *task, in_dim, out_dim, batch_size); + if (m->inference_debugging) { + assert(task->index_point.get_dim() == 1); + int shard_id = task->index_point.point_data[0]; + std::vector<GenericTensorAccessorR> weights_accessors; + weights_accessors.push_back(weight); + if (m->use_bias && + !(m->add_bias_only_once && task->index_point.point_data[0] != 0)) { + weights_accessors.push_back(bias); + } + Linear::save_inference_tensors_to_file( + m, shard_id, bc, {input}, weights_accessors, {output}); + } } void Linear::forward_task(Task const *task, @@ -1235,6 +1249,7 @@ bool operator==(LinearParams const &lhs, LinearParams const &rhs) { void Linear::serialize(Legion::Serializer &sez) const { sez.serialize(this->layer_guid.id); sez.serialize(this->layer_guid.transformer_layer_id); + sez.serialize(this->layer_guid.model_id); sez.serialize(this->out_channels); sez.serialize(this->activation); sez.serialize(this->kernel_reg_type); @@ -1260,10 +1275,11 @@ Node Linear::deserialize(FFModel &ff, DataType data_type; DataType quantization_type; bool offload; - size_t id, transformer_layer_id; + size_t id, transformer_layer_id, deserialized_model_id; dez.deserialize(id); dez.deserialize(transformer_layer_id); - LayerID layer_guid(id, transformer_layer_id); + dez.deserialize(deserialized_model_id); + LayerID layer_guid(id, transformer_layer_id, deserialized_model_id); dez.deserialize(out_channels); dez.deserialize(activation); dez.deserialize(kernel_reg_type); diff --git a/src/ops/pool_2d.cc b/src/ops/pool_2d.cc index f56a60641d..e358448ddf 100644 --- a/src/ops/pool_2d.cc +++ b/src/ops/pool_2d.cc @@ -317,7 +317,9 @@ OpMeta *Pool2D::init_task(Task const *task, FFHandler handle = *((FFHandler const *)task->local_args); Pool2DMeta *m = new Pool2DMeta(handle); m->profiling = pool->profiling; + m->inference_debugging = pool->inference_debugging; std::strcpy(m->op_name, pool->name); + m->layer_guid = pool->layer_guid; TensorAccessorR acc_input( regions[0], task->regions[0], FID_DATA, ctx, runtime); TensorAccessorW acc_output(regions[1], diff --git a/src/ops/reduce.cc b/src/ops/reduce.cc index 6c999c8858..7a443e6ad0 100644 --- a/src/ops/reduce.cc +++ b/src/ops/reduce.cc @@ -210,6 +210,8 @@ OpMeta *Reduce::init_task(Task const *task, GenericTensorAccessorW output = helperGetGenericTensorAccessorWO( DT_FLOAT, regions[1], task->regions[1], FID_DATA, ctx, runtime); ReduceMeta *m = new ReduceMeta(handle, rd, input.domain); + std::strcpy(m->op_name, rd->name); + m->layer_guid = rd->layer_guid; return m; } @@ -375,6 +377,7 @@ void Reduce::serialize(Legion::Serializer &sez) const { sez.serialize(params.keepdims); sez.serialize(this->layer_guid.id); sez.serialize(this->layer_guid.transformer_layer_id); + sez.serialize(this->layer_guid.model_id); } using PCG::Node; @@ -393,10 +396,11 @@ Node Reduce::deserialize(FFModel &ff, axes.push_back(dim_idx); }
dez.deserialize(keepdims); - size_t id, transformer_layer_id; + size_t id, transformer_layer_id, deserialized_model_id; dez.deserialize(id); dez.deserialize(transformer_layer_id); - LayerID layer_guid(id, transformer_layer_id); + dez.deserialize(deserialized_model_id); + LayerID layer_guid(id, transformer_layer_id, deserialized_model_id); return ff.get_or_create_node(inputs[0], {axes, keepdims, layer_guid}); } diff --git a/src/ops/reshape.cc b/src/ops/reshape.cc index 41c3fcdbf1..45da190680 100644 --- a/src/ops/reshape.cc +++ b/src/ops/reshape.cc @@ -181,6 +181,8 @@ OpMeta *Reshape::init_task(Task const *task, Reshape const *reshape = (Reshape *)task->args; FFHandler handle = *((FFHandler const *)task->local_args); ReshapeMeta *m = new ReshapeMeta(handle); + std::strcpy(m->op_name, reshape->name); + m->layer_guid = reshape->layer_guid; m->data_type = reshape->outputs[0]->data_type; return m; } @@ -411,6 +413,7 @@ void Reshape::serialize(Legion::Serializer &sez) const { } sez.serialize(this->layer_guid.id); sez.serialize(this->layer_guid.transformer_layer_id); + sez.serialize(this->layer_guid.model_id); } using PCG::Node; @@ -428,10 +431,11 @@ Node Reshape::deserialize(FFModel &ff, dez.deserialize(value); shape.push_back(value); } - size_t id, transformer_layer_id; + size_t id, transformer_layer_id, deserialized_model_id; dez.deserialize(id); dez.deserialize(transformer_layer_id); - LayerID layer_guid(id, transformer_layer_id); + dez.deserialize(deserialized_model_id); + LayerID layer_guid(id, transformer_layer_id, deserialized_model_id); ReshapeParams params; params.shape = shape; diff --git a/src/ops/residual_layer_norm.cc b/src/ops/residual_layer_norm.cc index 3bec09521a..7de40fb389 100644 --- a/src/ops/residual_layer_norm.cc +++ b/src/ops/residual_layer_norm.cc @@ -493,6 +493,8 @@ OpMeta *ResidualLayerNorm::init_task(Task const *task, MemoryAllocator gpu_mem_allocator(gpu_mem); ResidualLayerNormMeta *meta = new ResidualLayerNormMeta(handle, ln, gpu_mem_allocator); + std::strcpy(meta->op_name, ln->name); + meta->layer_guid = ln->layer_guid; meta->input_type[0] = ln->inputs[0]->data_type; meta->input_type[1] = ln->inputs[1]->data_type; if (ln->use_two_residuals) { @@ -622,8 +624,7 @@ void ResidualLayerNorm::inference_task( return; } - ResidualLayerNormMeta const *m = - *((ResidualLayerNormMeta **)task->local_args); + ResidualLayerNormMeta *m = *((ResidualLayerNormMeta **)task->local_args); assert(regions.size() == 4 + m->use_two_residuals + @@ -734,6 +735,30 @@ void ResidualLayerNorm::inference_task( ResidualLayerNorm::inference_kernel_wrapper( m, input, residual1, residual2, added_output, output, gamma, beta); + + if (m->inference_debugging) { + assert(task->index_point.get_dim() == 1); + int shard_id = task->index_point.point_data[0]; + std::vector<GenericTensorAccessorR> input_accessors; + input_accessors.push_back(input); + input_accessors.push_back(residual1); + if (m->use_two_residuals) { + input_accessors.push_back(residual2); + } + std::vector<GenericTensorAccessorR> weights_accessors; + if (m->elementwise_affine) { + weights_accessors.push_back(gamma); + if (m->use_bias) { + weights_accessors.push_back(beta); + } + } + ResidualLayerNorm::save_inference_tensors_to_file(m, + shard_id, + bc, + input_accessors, + weights_accessors, + {added_output, output}); + } } bool ResidualLayerNorm::measure_operator_cost(Simulator *sim, @@ -745,6 +770,7 @@ bool ResidualLayerNorm::measure_operator_cost(Simulator *sim, void ResidualLayerNorm::serialize(Legion::Serializer &sez) const { sez.serialize(this->layer_guid.id);
sez.serialize(this->layer_guid.transformer_layer_id); + sez.serialize(this->layer_guid.model_id); sez.serialize(this->axes.size()); for (size_t i = 0; i < this->axes.size(); i++) { sez.serialize(this->axes[i]); @@ -767,10 +793,11 @@ Node ResidualLayerNorm::deserialize(FFModel &ff, bool use_bias; bool use_two_residuals; float eps; - size_t id, transformer_layer_id; + size_t id, transformer_layer_id, deserialized_model_id; dez.deserialize(id); dez.deserialize(transformer_layer_id); - LayerID layer_guid(id, transformer_layer_id); + dez.deserialize(deserialized_model_id); + LayerID layer_guid(id, transformer_layer_id, deserialized_model_id); dez.deserialize(num_axes); for (size_t i = 0; i < num_axes; i++) { int axis_idx; @@ -811,6 +838,7 @@ size_t hash::operator()( size_t key = 0; hash_combine(key, params.layer_guid.id); hash_combine(key, params.layer_guid.transformer_layer_id); + hash_combine(key, params.layer_guid.model_id); hash_combine(key, params.axes.size()); for (int n : params.axes) { hash_combine(key, n); diff --git a/src/ops/residual_layer_norm.cpp b/src/ops/residual_layer_norm.cpp index 7b42392326..f1b7a537b0 100644 --- a/src/ops/residual_layer_norm.cpp +++ b/src/ops/residual_layer_norm.cpp @@ -34,6 +34,7 @@ ResidualLayerNormMeta::ResidualLayerNormMeta(FFHandler handle, effective_batch_size = ln->effective_batch_size; effective_num_elements = ln->effective_num_elements; profiling = ln->profiling; + inference_debugging = ln->inference_debugging; eps = ln->eps; DataType data_type = ln->data_type; size_t totalSize = effective_batch_size * data_type_size(data_type) * 3; diff --git a/src/ops/residual_layer_norm.cu b/src/ops/residual_layer_norm.cu index 3bd18217cb..e5ebdce6ed 100644 --- a/src/ops/residual_layer_norm.cu +++ b/src/ops/residual_layer_norm.cu @@ -33,6 +33,7 @@ ResidualLayerNormMeta::ResidualLayerNormMeta(FFHandler handle, effective_batch_size = ln->effective_batch_size; effective_num_elements = ln->effective_num_elements; profiling = ln->profiling; + inference_debugging = ln->inference_debugging; eps = ln->eps; DataType data_type = ln->data_type; size_t totalSize = effective_batch_size * data_type_size(data_type) * 3; diff --git a/src/ops/residual_rms_norm.cc b/src/ops/residual_rms_norm.cc index 3efb7274a0..b447a2a3b5 100644 --- a/src/ops/residual_rms_norm.cc +++ b/src/ops/residual_rms_norm.cc @@ -351,6 +351,8 @@ OpMeta *ResidualRMSNorm::init_task(Task const *task, MemoryAllocator gpu_mem_allocator(gpu_mem); ResidualRMSNormMeta *meta = new ResidualRMSNormMeta(handle, rn, gpu_mem_allocator); + std::strcpy(meta->op_name, rn->name); + meta->layer_guid = rn->layer_guid; return meta; } @@ -431,7 +433,7 @@ void ResidualRMSNorm::inference_task(Task const *task, if (bc->num_tokens == 0) { return; } - ResidualRMSNormMeta const *m = *((ResidualRMSNormMeta **)task->local_args); + ResidualRMSNormMeta *m = *((ResidualRMSNormMeta **)task->local_args); GenericTensorAccessorR input1 = helperGetGenericTensorAccessorRO( m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); GenericTensorAccessorR input2 = helperGetGenericTensorAccessorRO( @@ -443,11 +445,18 @@ void ResidualRMSNorm::inference_task(Task const *task, GenericTensorAccessorR weight = helperGetGenericTensorAccessorRO( m->weight_type[0], regions[4], task->regions[4], FID_DATA, ctx, runtime); forward_kernel_wrapper(m, input1, input2, weight, residual_output, output); + if (m->inference_debugging) { + assert(task->index_point.get_dim() == 1); + int shard_id = task->index_point.point_data[0]; + 
ResidualRMSNorm::save_inference_tensors_to_file( + m, shard_id, bc, {input1, input2}, {weight}, {residual_output, output}); + } } void ResidualRMSNorm::serialize(Legion::Serializer &sez) const { sez.serialize(this->layer_guid.id); sez.serialize(this->layer_guid.transformer_layer_id); + sez.serialize(this->layer_guid.model_id); sez.serialize(this->eps); sez.serialize(this->dim); } @@ -460,12 +469,12 @@ Node ResidualRMSNorm::deserialize(FFModel &ff, int num_inputs) { assert(num_inputs == 2); float eps; - size_t id, transformer_layer_id; + size_t id, transformer_layer_id, deserialized_model_id; int dim; dez.deserialize(id); dez.deserialize(transformer_layer_id); - - LayerID layer_guid(id, transformer_layer_id); + dez.deserialize(deserialized_model_id); + LayerID layer_guid(id, transformer_layer_id, deserialized_model_id); dez.deserialize(eps); dez.deserialize(dim); ResidualRMSNormParams params; diff --git a/src/ops/rms_norm.cc b/src/ops/rms_norm.cc index 5b1634472d..2a34f83be2 100644 --- a/src/ops/rms_norm.cc +++ b/src/ops/rms_norm.cc @@ -297,6 +297,8 @@ OpMeta *RMSNorm::init_task(Task const *task, .first(); MemoryAllocator gpu_mem_allocator(gpu_mem); RMSNormMeta *meta = new RMSNormMeta(handle, rn, gpu_mem_allocator); + std::strcpy(meta->op_name, rn->name); + meta->layer_guid = rn->layer_guid; return meta; } @@ -413,7 +415,7 @@ void RMSNorm::inference_task(Task const *task, if (bc->num_tokens == 0) { return; } - RMSNormMeta const *m = *((RMSNormMeta **)task->local_args); + RMSNormMeta *m = *((RMSNormMeta **)task->local_args); GenericTensorAccessorR input = helperGetGenericTensorAccessorRO( m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); GenericTensorAccessorW output = helperGetGenericTensorAccessorWO( @@ -421,11 +423,18 @@ void RMSNorm::inference_task(Task const *task, GenericTensorAccessorR weight = helperGetGenericTensorAccessorRO( m->weight_type[0], regions[2], task->regions[2], FID_DATA, ctx, runtime); forward_kernel_wrapper(m, input, weight, output); + if (m->inference_debugging) { + assert(task->index_point.get_dim() == 1); + int shard_id = task->index_point.point_data[0]; + RMSNorm::save_inference_tensors_to_file( + m, shard_id, bc, {input}, {weight}, {output}); + } } void RMSNorm::serialize(Legion::Serializer &sez) const { sez.serialize(this->layer_guid.id); sez.serialize(this->layer_guid.transformer_layer_id); + sez.serialize(this->layer_guid.model_id); sez.serialize(this->eps); sez.serialize(this->dim); } @@ -438,12 +447,13 @@ Node RMSNorm::deserialize(FFModel &ff, int num_inputs) { assert(num_inputs == 1); float eps; - size_t id, transformer_layer_id; + size_t id, transformer_layer_id, deserialized_model_id; int dim; dez.deserialize(id); dez.deserialize(transformer_layer_id); + dez.deserialize(deserialized_model_id); - LayerID layer_guid(id, transformer_layer_id); + LayerID layer_guid(id, transformer_layer_id, deserialized_model_id); dez.deserialize(eps); dez.deserialize(dim); RMSNormParams params; diff --git a/src/ops/sampling.cc b/src/ops/sampling.cc index 6eb62b2933..463b15aadb 100644 --- a/src/ops/sampling.cc +++ b/src/ops/sampling.cc @@ -231,6 +231,9 @@ OpMeta *Sampling::init_task(Task const *task, SamplingMeta *m = new SamplingMeta( handle, s, batch_size, length * batch_size, acc_input, gpu_mem_allocator); m->profiling = s->profiling; + m->inference_debugging = s->inference_debugging; + std::strcpy(m->op_name, s->name); + m->layer_guid = s->layer_guid; m->top_p = s->top_p; return m; } @@ -287,7 +290,7 @@ InferenceResult assert(task->regions.size() == 
2); BatchConfig const *bc = BatchConfig::from_future(task->futures[0]); // BatchConfig const *bc = (BatchConfig *)task->args; - SamplingMeta const *m = *((SamplingMeta **)task->local_args); + SamplingMeta *m = *((SamplingMeta **)task->local_args); if (bc->num_tokens == 0) { // Directly return for empty batch config InferenceResult ir; @@ -302,6 +305,13 @@ InferenceResult int batch_size = bc->num_active_tokens(); Sampling::forward_kernel_wrapper(m, input, indices, batch_size); + if (m->inference_debugging) { + assert(task->index_point.get_dim() == 1); + int shard_id = task->index_point.point_data[0]; + Sampling::save_inference_tensors_to_file( + m, shard_id, bc, {}, {}, {input, indices}); + } + InferenceResult ir; download_tensor( indices.get_int32_ptr(), ir.token_ids, batch_size); diff --git a/src/ops/sigmoid_silu_multi.cc b/src/ops/sigmoid_silu_multi.cc index 031c7833a4..3b2ed7cef4 100644 --- a/src/ops/sigmoid_silu_multi.cc +++ b/src/ops/sigmoid_silu_multi.cc @@ -244,6 +244,8 @@ OpMeta *SigmoidSiluMulti::init_task(Task const *task, meta->input_type[0] = ssm->inputs[0]->data_type; meta->input_type[1] = ssm->inputs[1]->data_type; meta->output_type[0] = ssm->outputs[0]->data_type; + std::strcpy(meta->op_name, ssm->name); + meta->layer_guid = ssm->layer_guid; return meta; } @@ -323,7 +325,7 @@ void SigmoidSiluMulti::inference_task( return; } - SigmoidSiluMultiMeta const *m = *((SigmoidSiluMultiMeta **)task->local_args); + SigmoidSiluMultiMeta *m = *((SigmoidSiluMultiMeta **)task->local_args); GenericTensorAccessorR input1 = helperGetGenericTensorAccessorRO( m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); @@ -346,6 +348,12 @@ void SigmoidSiluMulti::inference_task( assert(input1_domain == output_domain); SigmoidSiluMulti::inference_kernel_wrapper(m, input1, input2, output); + if (m->inference_debugging) { + assert(task->index_point.get_dim() == 1); + int shard_id = task->index_point.point_data[0]; + SigmoidSiluMulti::save_inference_tensors_to_file( + m, shard_id, bc, {input1, input2}, {}, {output}); + } } bool SigmoidSiluMulti::measure_operator_cost(Simulator *sim, @@ -357,6 +365,7 @@ bool SigmoidSiluMulti::measure_operator_cost(Simulator *sim, void SigmoidSiluMulti::serialize(Legion::Serializer &sez) const { sez.serialize(this->layer_guid.id); sez.serialize(this->layer_guid.transformer_layer_id); + sez.serialize(this->layer_guid.model_id); } using PCG::Node; @@ -366,10 +375,11 @@ Node SigmoidSiluMulti::deserialize(FFModel &ff, ParallelTensor inputs[], int num_inputs) { assert(num_inputs == 2); - size_t id, transformer_layer_id; + size_t id, transformer_layer_id, deserialized_model_id; dez.deserialize(id); dez.deserialize(transformer_layer_id); - LayerID layer_guid(id, transformer_layer_id); + dez.deserialize(deserialized_model_id); + LayerID layer_guid(id, transformer_layer_id, deserialized_model_id); SigmoidSiluMultiParams params; params.layer_guid = layer_guid; @@ -385,6 +395,7 @@ size_t hash::operator()( size_t key = 0; hash_combine(key, params.layer_guid.id); hash_combine(key, params.layer_guid.transformer_layer_id); + hash_combine(key, params.layer_guid.model_id); return key; } }; // namespace std diff --git a/src/ops/sigmoid_silu_multi.cpp b/src/ops/sigmoid_silu_multi.cpp index fa73a55722..7b7f30a288 100644 --- a/src/ops/sigmoid_silu_multi.cpp +++ b/src/ops/sigmoid_silu_multi.cpp @@ -25,6 +25,7 @@ SigmoidSiluMultiMeta::SigmoidSiluMultiMeta(FFHandler handle, MemoryAllocator &gpu_mem_allocator) : OpMeta(handle) { profiling = ssm->profiling; + inference_debugging = 
ssm->inference_debugging; } SigmoidSiluMultiMeta::~SigmoidSiluMultiMeta(void) { diff --git a/src/ops/sigmoid_silu_multi.cu b/src/ops/sigmoid_silu_multi.cu index fd69f6a8aa..590b641b5a 100644 --- a/src/ops/sigmoid_silu_multi.cu +++ b/src/ops/sigmoid_silu_multi.cu @@ -24,6 +24,7 @@ SigmoidSiluMultiMeta::SigmoidSiluMultiMeta(FFHandler handle, MemoryAllocator &gpu_mem_allocator) : OpMeta(handle) { profiling = ssm->profiling; + inference_debugging = ssm->inference_debugging; } SigmoidSiluMultiMeta::~SigmoidSiluMultiMeta(void) { diff --git a/src/ops/softmax.cc b/src/ops/softmax.cc index 28c9ecea67..ba0a1288d6 100644 --- a/src/ops/softmax.cc +++ b/src/ops/softmax.cc @@ -230,6 +230,8 @@ OpMeta *Softmax::init_task(Task const *task, m->input_type = softmax->inputs[0]->data_type; m->output_type = softmax->outputs[0]->data_type; // checkCUDNN(cudnnCreateTensorDescriptor(&m->outputTensor)); + std::strcpy(m->op_name, softmax->name); + m->layer_guid = softmax->layer_guid; return m; } @@ -303,50 +305,25 @@ void Softmax::forward_task(Task const *task, std::vector const ®ions, Context ctx, Runtime *runtime) { + assert(regions.size() == 2); + assert(task->regions.size() == 2); Domain in_domain = runtime->get_index_space_domain( ctx, task->regions[0].region.get_index_space()); SoftmaxMeta const *m = *((SoftmaxMeta **)task->local_args); - switch (in_domain.get_dim()) { -#define DIMFUNC(DIM) \ - case DIM: \ - if (m->output_type == DT_HALF) { \ - return forward_task_with_dim(task, regions, ctx, runtime); \ - } else if (m->output_type == DT_FLOAT) { \ - return forward_task_with_dim(task, regions, ctx, runtime); \ - } else { \ - assert(false && "Unsupported data type"); \ - } - LEGION_FOREACH_N(DIMFUNC) -#undef DIMFUNC - default: - assert(false); + GenericTensorAccessorR input = helperGetGenericTensorAccessorRO( + m->output_type, regions[0], task->regions[0], FID_DATA, ctx, runtime); + GenericTensorAccessorW output = helperGetGenericTensorAccessorWO( + m->output_type, regions[1], task->regions[1], FID_DATA, ctx, runtime); + + if (m->output_type == DT_HALF) { + forward_kernel_wrapper(m, input.get_half_ptr(), output.get_half_ptr()); + } else if (m->output_type == DT_FLOAT) { + forward_kernel_wrapper(m, input.get_float_ptr(), output.get_float_ptr()); + } else { + assert(false && "Unsupported data type"); } } -/* - regions[0](I): input - regions[1](O): output -*/ -template -void Softmax::forward_task_with_dim(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 2); - assert(task->regions.size() == 2); - // const Softmax* softmax = (Softmax*) task->args; - SoftmaxMeta const *m = *((SoftmaxMeta **)task->local_args); - TensorAccessorR acc_input( - regions[0], task->regions[0], FID_DATA, ctx, runtime); - TensorAccessorW acc_output(regions[1], - task->regions[1], - FID_DATA, - ctx, - runtime, - false /*readOutput*/); - forward_kernel_wrapper(m, acc_input.ptr, acc_output.ptr); -} - void Softmax::backward(FFModel const &ff) { ArgumentMap argmap; Context ctx = ff.config.lg_ctx; @@ -435,29 +412,31 @@ void Softmax::inference_task(Task const *task, Context ctx, Runtime *runtime) { assert(task->regions.size() == regions.size()); + assert(regions.size() == 2); + assert(task->regions.size() == 2); BatchConfig const *bc = BatchConfig::from_future(task->futures[0]); if (bc->num_tokens == 0) { return; } Domain in_domain = runtime->get_index_space_domain( ctx, task->regions[0].region.get_index_space()); - SoftmaxMeta const *m = *((SoftmaxMeta **)task->local_args); - switch 
(in_domain.get_dim()) { -#define DIMFUNC(DIM) \ - case DIM: \ - if (m->output_type == DT_HALF) { \ - forward_task_with_dim(task, regions, ctx, runtime); \ - break; \ - } else if (m->output_type == DT_FLOAT) { \ - forward_task_with_dim(task, regions, ctx, runtime); \ - break; \ - } else { \ - assert(false && "Unsupported data type"); \ - } - LEGION_FOREACH_N(DIMFUNC) -#undef DIMFUNC - default: - assert(false); + SoftmaxMeta *m = *((SoftmaxMeta **)task->local_args); + GenericTensorAccessorR input = helperGetGenericTensorAccessorRO( + m->output_type, regions[0], task->regions[0], FID_DATA, ctx, runtime); + GenericTensorAccessorW output = helperGetGenericTensorAccessorWO( + m->output_type, regions[1], task->regions[1], FID_DATA, ctx, runtime); + if (m->output_type == DT_HALF) { + forward_kernel_wrapper(m, input.get_half_ptr(), output.get_half_ptr()); + } else if (m->output_type == DT_FLOAT) { + forward_kernel_wrapper(m, input.get_float_ptr(), output.get_float_ptr()); + } else { + assert(false && "Unsupported data type"); + } + if (m->inference_debugging) { + assert(task->index_point.get_dim() == 1); + int shard_id = task->index_point.point_data[0]; + Softmax::save_inference_tensors_to_file( + m, shard_id, bc, {input}, {}, {output}); } } diff --git a/src/ops/spec_inc_multihead_self_attention.cc b/src/ops/spec_inc_multihead_self_attention.cc index 350ab3c167..4c78960d5f 100644 --- a/src/ops/spec_inc_multihead_self_attention.cc +++ b/src/ops/spec_inc_multihead_self_attention.cc @@ -658,6 +658,9 @@ OpMeta *SpecIncMultiHeadSelfAttention::init_task( assert(gpu_mem_allocator.instance_allocated_size == gpu_mem_allocator.instance_total_size); m->profiling = attn->profiling; + m->inference_debugging = attn->inference_debugging; + std::strcpy(m->op_name, attn->name); + m->layer_guid = attn->layer_guid; assert(weight.domain.get_volume() * data_type_size(weight.data_type) == m->weightSize); return m; @@ -733,14 +736,13 @@ void SpecIncMultiHeadSelfAttention::inference_task( Runtime *runtime) { assert(task->regions.size() == regions.size()); - // BeamSearchBatchConfig const *bc = (BeamSearchBatchConfig *)task->args; BeamSearchBatchConfig const &bc = Future(task->futures[0]).get_result(); if (bc.num_tokens == 0) { return; } - SpecIncMultiHeadSelfAttentionMeta const *m = + SpecIncMultiHeadSelfAttentionMeta *m = *((SpecIncMultiHeadSelfAttentionMeta **)task->local_args); assert(((*m->qkv_bias || *m->final_bias) ? 
regions.size() == 4 : regions.size() == 3)); @@ -777,14 +779,17 @@ void SpecIncMultiHeadSelfAttention::inference_task( assert(task->index_point.get_dim() == 1); SpecIncMultiHeadSelfAttention::inference_kernel_wrapper( m, &bc, task->index_point.point_data[0], input, weight, output, biases); - - // print_tensor(input.get_float_ptr(), 20, "attention input"); - // print_tensor(output.get_float_ptr(), 20, "attention output"); - // if(bc.beam_slots.at(0).current_depth == 1){ - // print_beam_tensor(input.get_float_ptr(), 50, 4096, 40, "mha topk - // input"); print_beam_tensor(output.get_float_ptr(), 50, 4096, 40, - // "mha topk output"); - // } + if (m->inference_debugging) { + assert(task->index_point.get_dim() == 1); + int shard_id = task->index_point.point_data[0]; + std::vector weights_accessors; + weights_accessors.push_back(weight); + if (*m->qkv_bias || *m->final_bias) { + weights_accessors.push_back(biases); + } + SpecIncMultiHeadSelfAttention::save_inference_tensors_to_file( + m, shard_id, &bc, {input}, weights_accessors, {output}); + } } void SpecIncMultiHeadSelfAttention::backward(FFModel const &ff) { diff --git a/src/ops/topk.cc b/src/ops/topk.cc index d76ad75167..b38ff85f90 100644 --- a/src/ops/topk.cc +++ b/src/ops/topk.cc @@ -225,7 +225,10 @@ OpMeta *TopK::init_task(Task const *task, FFHandler handle = *((FFHandler *)task->local_args); TopKMeta *m = new TopKMeta(handle); m->profiling = topk->profiling; + m->inference_debugging = topk->inference_debugging; m->sorted = topk->sorted; + std::strcpy(m->op_name, topk->name); + m->layer_guid = topk->layer_guid; return m; } diff --git a/src/ops/transpose.cc b/src/ops/transpose.cc index 303948964b..500b7867af 100644 --- a/src/ops/transpose.cc +++ b/src/ops/transpose.cc @@ -193,6 +193,9 @@ OpMeta *Transpose::init_task(Task const *task, TransposeMeta *m = new TransposeMeta(handle); transpose->init_meta(m, in_domain, out_domain); m->profiling = transpose->profiling; + m->inference_debugging = transpose->inference_debugging; + std::strcpy(m->op_name, transpose->name); + m->layer_guid = transpose->layer_guid; return m; } diff --git a/src/ops/tree_inc_multihead_self_attention.cc b/src/ops/tree_inc_multihead_self_attention.cc index 207dae0785..6b520aa37b 100644 --- a/src/ops/tree_inc_multihead_self_attention.cc +++ b/src/ops/tree_inc_multihead_self_attention.cc @@ -722,6 +722,9 @@ OpMeta *TreeIncMultiHeadSelfAttention::init_task( gpu_mem_allocator.reserved_total_size); } m->profiling = attn->profiling; + m->inference_debugging = attn->inference_debugging; + std::strcpy(m->op_name, attn->name); + m->layer_guid = attn->layer_guid; if (attn->quantization_type == DT_NONE) { assert(weight.domain.get_volume() * data_type_size(weight.data_type) == @@ -803,7 +806,6 @@ void TreeIncMultiHeadSelfAttention::inference_task( Runtime *runtime) { assert(task->regions.size() == regions.size()); - // TreeVerifyBatchConfig const *bc = (TreeVerifyBatchConfig *)task->args; TreeVerifyBatchConfig const &bc = Future(task->futures[0]).get_result(); log_tree_verify.debug( @@ -857,6 +859,18 @@ void TreeIncMultiHeadSelfAttention::inference_task( TreeIncMultiHeadSelfAttention::inference_kernel_wrapper( m, &bc, task->index_point.point_data[0], input, weight, output, biases); + + if (m->inference_debugging) { + assert(task->index_point.get_dim() == 1); + int shard_id = task->index_point.point_data[0]; + std::vector weights_accessors; + weights_accessors.push_back(weight); + if (*m->qkv_bias || *m->final_bias) { + weights_accessors.push_back(biases); + } + 
TreeIncMultiHeadSelfAttention::save_inference_tensors_to_file( + m, shard_id, &bc, {input}, weights_accessors, {output}); + } } void TreeIncMultiHeadSelfAttention::backward(FFModel const &ff) { diff --git a/src/runtime/batch_config.cc b/src/runtime/batch_config.cc index cbf839c6b2..72572c4e06 100644 --- a/src/runtime/batch_config.cc +++ b/src/runtime/batch_config.cc @@ -88,43 +88,61 @@ int BatchConfig::max_sequence_length() { return RequestManager::get_request_manager()->get_max_sequence_length(); } -void BatchConfig::print() const { - std::cout << "@@@@@@@@@@@@@@ Batch Config (mode " << get_mode() - << ") @@@@@@@@@@@@@@" << std::endl; - std::cout << "Max number of requests: " << max_requests_per_batch() - << std::endl; - std::cout << "Max number of tokens: " << max_tokens_per_batch() << std::endl; - std::cout << "Number of tokens: " << num_tokens << std::endl; - std::cout << "Number of requests: " << num_active_requests() << std::endl; - // std::cout << "Cached results: " << cached_results << std::endl; - - std::cout << "Per-request info:\n"; - for (int i = 0; i < max_requests_per_batch(); i++) { - if (!request_completed[i]) { - std::cout << " Request " << i << ":\n"; - std::cout << " Token start offset: " - << requestsInfo[i].token_start_offset << std::endl; - std::cout << " Number of tokens in batch: " - << requestsInfo[i].num_tokens_in_batch << std::endl; - std::cout << " GUID: " << requestsInfo[i].request_guid << std::endl; - std::cout << " Max sequence length: " - << requestsInfo[i].max_sequence_length << std::endl; - std::cout << " Request completed: " << request_completed[i] - << std::endl; +std::ostream &operator<<(std::ostream &os, BatchConfig const &bc) { + os << "@@@@@@@@@@@@@@ Batch Config (mode " << bc.get_mode() + << ") @@@@@@@@@@@@@@" << std::endl; + // Max values + os << "Max number of requests: " << bc.max_requests_per_batch() << std::endl; + os << "Max number of tokens: " << bc.max_tokens_per_batch() << std::endl; + os << "Max sequence length: " << bc.max_sequence_length() << std::endl; + // Current values + os << "Number of tokens: " << bc.num_active_tokens() << std::endl; + os << "Number of requests: " << bc.num_active_requests() << std::endl; + + // Per-request info + os << "Per-request info:\n"; + for (int i = 0; i < bc.max_requests_per_batch(); i++) { + if (!bc.request_completed[i]) { + os << " Request " << i << ":\n"; + os << " Token start offset: " << bc.requestsInfo[i].token_start_offset + << std::endl; + os << " Number of tokens in batch: " + << bc.requestsInfo[i].num_tokens_in_batch << std::endl; + os << " GUID: " << bc.requestsInfo[i].request_guid << std::endl; + os << " Max sequence length: " + << bc.requestsInfo[i].max_sequence_length << std::endl; + os << " Request completed: " << bc.request_completed[i] << std::endl; + os << " Request running: " << bc.request_running[i] << std::endl; } } - std::cout << "Per-token info:\n"; - for (int i = 0; i < num_tokens; i++) { - std::cout << " Token " << i << ":\n"; - std::cout << " Absolute depth in request: " - << tokensInfo[i].abs_depth_in_request << std::endl; - std::cout << " Request index: " << tokensInfo[i].request_index - << std::endl; - std::cout << " Token id: " << tokensInfo[i].token_id << std::endl; + // Per-token info + os << "Per-token info:\n"; + for (int i = 0; i < bc.num_tokens; i++) { + os << " Token " << i << ":\n"; + os << " Absolute depth in request: " + << bc.tokensInfo[i].abs_depth_in_request << std::endl; + os << " Request index: " << bc.tokensInfo[i].request_index << std::endl; + os << " Token 
id: " << bc.tokensInfo[i].token_id << std::endl; + } + os << "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@" << std::endl; + return os; +} + +void BatchConfig::print() const { + std::cout << *this << std::endl; +} + +void BatchConfig::save_to_file(std::string const &filename) const { + std::ofstream outputFile(filename); + if (outputFile.is_open()) { + outputFile << *this << std::endl; + outputFile.close(); + } else { + std::cerr << "Error: Unable to open the batch config output file: " + << filename << std::endl; + assert(false); } - std::cout << "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@" - << std::endl; } }; // namespace FlexFlow diff --git a/src/runtime/beam_search_batch_config.cc b/src/runtime/beam_search_batch_config.cc index 634d60a352..811ef00ba2 100644 --- a/src/runtime/beam_search_batch_config.cc +++ b/src/runtime/beam_search_batch_config.cc @@ -14,6 +14,7 @@ */ #include "flexflow/batch_config.h" +#include "flexflow/request_manager.h" #include "legion.h" #include #include @@ -101,72 +102,93 @@ int BeamSearchBatchConfig::current_depth_all_requests() const { return current_depth; } -void BeamSearchBatchConfig::print() const { - std::cout << "@@@@@@@@@@@@@@ BeamSearchBatchConfig (mode " << get_mode() - << ") @@@@@@@@@@@@@@" << std::endl; - std::cout << "Max number of requests: " << max_requests_per_batch() - << std::endl; - std::cout << "Max number of tokens: " << max_tokens_per_batch() << std::endl; - std::cout << "Number of tokens: " << num_tokens << std::endl; - std::cout << "Number of requests: " << num_active_requests() << std::endl; - std::cout << "Beam width: " << beam_width << std::endl; - std::cout << "Target Iterations: " << target_iterations << std::endl; - std::cout << "Current Iterations: " << current_iteration << std::endl; - - std::cout << "Per-request info:\n"; - for (int i = 0; i < max_requests_per_batch(); i++) { - // assert(beamRequestsInfo[i].request_completed == request_completed[i]); - if (!request_completed[i]) { - std::cout << " Request " << i << ":\n"; - std::cout << " Token start offset: " - << requestsInfo[i].token_start_offset << std::endl; - std::cout << " Number of tokens in batch: " - << requestsInfo[i].num_tokens_in_batch << std::endl; - std::cout << " GUID: " << requestsInfo[i].request_guid << std::endl; - std::cout << " Max sequence length: " - << requestsInfo[i].max_sequence_length << std::endl; - std::cout << " Beam Search Specific: " << std::endl; - std::cout << " beam_size: " << beamRequestsInfo[i].beam_size - << std::endl; - std::cout << " current_depth: " - << beamRequestsInfo[i].current_depth << std::endl; - std::cout << " max_depth: " << beamRequestsInfo[i].max_depth - << std::endl; - std::cout << " tokens: "; - for (int j = 0; j < MAX_BEAM_WIDTH; j++) { - std::cout << beamRequestsInfo[i].tokens[j] << ", "; +std::ostream &operator<<(std::ostream &os, BeamSearchBatchConfig const &bc) { + os << "@@@@@@@@@@@@@@ BeamSearchBatchConfig (mode " << bc.get_mode() + << ") @@@@@@@@@@@@@@" << std::endl; + // Max values + os << "Max number of requests: " << bc.max_requests_per_batch() << std::endl; + os << "Max number of tokens: " << bc.max_tokens_per_batch() << std::endl; + os << "Max sequence length: " << bc.max_sequence_length() << std::endl; + // Current values + os << "Number of tokens: " << bc.num_active_tokens() << std::endl; + os << "Number of requests: " << bc.num_active_requests() << std::endl; + // BeamSearch-specific + os << "Model ID: " << bc.model_id << std::endl; + os << "Max Beam Depth (all requests): " << 
bc.max_beam_depth_all_requests() + << std::endl; + os << "Current depth (all requests): " << bc.current_depth_all_requests() + << std::endl; + os << "Beam width: " << bc.beam_width << std::endl; + os << "Target Iterations: " << bc.target_iterations << std::endl; + os << "Current Iterations: " << bc.current_iteration << std::endl; + + os << "Per-request info:\n"; + for (int i = 0; i < bc.max_requests_per_batch(); i++) { + if (!bc.request_completed[i]) { + os << " Request " << i << ":\n"; + os << " Token start offset: " << bc.requestsInfo[i].token_start_offset + << std::endl; + os << " Number of tokens in batch: " + << bc.requestsInfo[i].num_tokens_in_batch << std::endl; + os << " GUID: " << bc.requestsInfo[i].request_guid << std::endl; + os << " Max sequence length: " + << bc.requestsInfo[i].max_sequence_length << std::endl; + os << " Request completed: " << bc.request_completed[i] << std::endl; + os << " Request running: " << bc.request_running[i] << std::endl; + os << " Beam Search Specific: " << std::endl; + os << " beam_size: " << bc.beamRequestsInfo[i].beam_size + << std::endl; + os << " current_depth: " << bc.beamRequestsInfo[i].current_depth + << std::endl; + os << " max_depth: " << bc.beamRequestsInfo[i].max_depth + << std::endl; + os << " tokens: "; + for (int j = 0; j < bc.MAX_BEAM_WIDTH; j++) { + os << bc.beamRequestsInfo[i].tokens[j] << ", "; } - std::cout << std::endl; - std::cout << " probs: "; - for (int j = 0; j < MAX_BEAM_WIDTH; j++) { - std::cout << beamRequestsInfo[i].probs[j] << ", "; + os << std::endl; + os << " probs: "; + for (int j = 0; j < bc.MAX_BEAM_WIDTH; j++) { + os << bc.beamRequestsInfo[i].probs[j] << ", "; } - std::cout << std::endl; - std::cout << " parent_id: "; - for (int j = 0; j < MAX_BEAM_WIDTH; j++) { - std::cout << beamRequestsInfo[i].parent_id[j] << ", "; + os << std::endl; + os << " parent_id: "; + for (int j = 0; j < bc.MAX_BEAM_WIDTH; j++) { + os << bc.beamRequestsInfo[i].parent_id[j] << ", "; } - std::cout << std::endl; + os << std::endl; } } - std::cout << "Per-token info:\n"; - for (int i = 0; i < num_tokens; i++) { - std::cout << " Token " << i << ":\n"; - std::cout << " Absolute depth in request: " - << tokensInfo[i].abs_depth_in_request << std::endl; - std::cout << " Request index: " << tokensInfo[i].request_index - << std::endl; - std::cout << " Token id: " << tokensInfo[i].token_id << std::endl; - std::cout << " Beam Search Specific: " << std::endl; - std::cout << " beam_size: " << beamTokenInfo[i].sub_request_index - << std::endl; - // std::cout << " Parent token id: " << tokensInfo[i].parent_token_id << - // std::endl; std::cout << " Accumulated log prob: " - // << tokensInfo[i].cum_log_prob << std::endl; + os << "Per-token info:\n"; + for (int i = 0; i < bc.num_tokens; i++) { + os << " Token " << i << ":\n"; + os << " Absolute depth in request: " + << bc.tokensInfo[i].abs_depth_in_request << std::endl; + os << " Request index: " << bc.tokensInfo[i].request_index << std::endl; + os << " Token id: " << bc.tokensInfo[i].token_id << std::endl; + os << " Beam Search Specific: " << std::endl; + os << " beam_size: " << bc.beamTokenInfo[i].sub_request_index + << std::endl; + } + os << "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@" << std::endl; + return os; +} + +void BeamSearchBatchConfig::print() const { + std::cout << *this << std::endl; +} + +void BeamSearchBatchConfig::save_to_file(std::string const &filename) const { + std::ofstream outputFile(filename); + if (outputFile.is_open()) { + outputFile << *this << std::endl; + 
outputFile.close(); + } else { + std::cerr << "Error: Unable to open the batch config output file: " + << filename << std::endl; + assert(false); } - std::cout << "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@" - << std::endl; } }; // namespace FlexFlow diff --git a/src/runtime/cuda_helper.cu b/src/runtime/cuda_helper.cu index e4728bdb88..fa6bf55fe5 100644 --- a/src/runtime/cuda_helper.cu +++ b/src/runtime/cuda_helper.cu @@ -268,23 +268,106 @@ __host__ void print_beam_tensor(T const *ptr, checkCUDA(cudaFreeHost(host_ptr)); } -template +template <> __host__ void - save_tensor(T const *ptr, size_t num_elements, char const *file_name) { + save_tensor(float const *ptr, size_t num_elements, char const *file_name) { cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); - T *host_ptr; + float *host_ptr; checkCUDA(cudaHostAlloc(&host_ptr, - sizeof(T) * num_elements, + sizeof(float) * num_elements, cudaHostAllocPortable | cudaHostAllocMapped)); - checkCUDA(cudaMemcpyAsync( - host_ptr, ptr, sizeof(T) * num_elements, cudaMemcpyDeviceToHost, stream)); - // checkCUDA(cudaDeviceSynchronize()); - cudaDeviceSynchronize(); + checkCUDA(cudaMemcpyAsync(host_ptr, + ptr, + sizeof(float) * num_elements, + cudaMemcpyDeviceToHost, + stream)); + checkCUDA(cudaDeviceSynchronize()); FILE *tensor_file; tensor_file = fopen(file_name, "w"); + assert(tensor_file != NULL); for (unsigned i = 0; i < num_elements; i++) { - fprintf(tensor_file, "%.20f, ", (float)host_ptr[i]); + fprintf(tensor_file, "%.9f, ", host_ptr[i]); + } + + fclose(tensor_file); + checkCUDA(cudaFreeHost(host_ptr)); +} + +template <> +__host__ void + save_tensor(half const *ptr, size_t num_elements, char const *file_name) { + cudaStream_t stream; + checkCUDA(get_legion_stream(&stream)); + half *host_ptr; + checkCUDA(cudaHostAlloc(&host_ptr, + sizeof(half) * num_elements, + cudaHostAllocPortable | cudaHostAllocMapped)); + checkCUDA(cudaMemcpyAsync(host_ptr, + ptr, + sizeof(half) * num_elements, + cudaMemcpyDeviceToHost, + stream)); + checkCUDA(cudaDeviceSynchronize()); + FILE *tensor_file; + tensor_file = fopen(file_name, "w"); + assert(tensor_file != NULL); + for (unsigned i = 0; i < num_elements; i++) { + fprintf(tensor_file, "%.9f, ", (float)host_ptr[i]); + } + + fclose(tensor_file); + checkCUDA(cudaFreeHost(host_ptr)); +} + +template <> +__host__ void save_tensor(int32_t const *ptr, + size_t num_elements, + char const *file_name) { + cudaStream_t stream; + checkCUDA(get_legion_stream(&stream)); + int32_t *host_ptr; + checkCUDA(cudaHostAlloc(&host_ptr, + sizeof(int32_t) * num_elements, + cudaHostAllocPortable | cudaHostAllocMapped)); + checkCUDA(cudaMemcpyAsync(host_ptr, + ptr, + sizeof(int32_t) * num_elements, + cudaMemcpyDeviceToHost, + stream)); + checkCUDA(cudaDeviceSynchronize()); + FILE *tensor_file; + tensor_file = fopen(file_name, "w"); + assert(tensor_file != NULL); + for (unsigned i = 0; i < num_elements; i++) { + fprintf(tensor_file, "%d, ", host_ptr[i]); + } + + fclose(tensor_file); + checkCUDA(cudaFreeHost(host_ptr)); +} + +template <> +__host__ void save_tensor(int64_t const *ptr, + size_t num_elements, + char const *file_name) { + cudaStream_t stream; + checkCUDA(get_legion_stream(&stream)); + int64_t *host_ptr; + checkCUDA(cudaHostAlloc(&host_ptr, + sizeof(int64_t) * num_elements, + cudaHostAllocPortable | cudaHostAllocMapped)); + checkCUDA(cudaMemcpyAsync(host_ptr, + ptr, + sizeof(int64_t) * num_elements, + cudaMemcpyDeviceToHost, + stream)); + checkCUDA(cudaDeviceSynchronize()); + FILE *tensor_file; + tensor_file = 
fopen(file_name, "w"); + assert(tensor_file != NULL); + for (unsigned i = 0; i < num_elements; i++) { + fprintf(tensor_file, "%ld, ", host_ptr[i]); } fclose(tensor_file); @@ -608,6 +691,9 @@ template __host__ void print_beam_tensor(int64_t const *ptr, template __host__ void save_tensor(float const *ptr, size_t rect, char const *file_name); +template __host__ void save_tensor(int32_t const *ptr, + size_t rect, + char const *file_name); template __host__ void save_tensor(int64_t const *ptr, size_t rect, char const *file_name); diff --git a/src/runtime/fftype.cc b/src/runtime/fftype.cc index 2b94f07999..819e6527e5 100644 --- a/src/runtime/fftype.cc +++ b/src/runtime/fftype.cc @@ -6,21 +6,26 @@ namespace FlexFlow { const LayerID LayerID::NO_ID = LayerID(); -LayerID::LayerID() : id(0), transformer_layer_id(MAX_NUM_TRANSFORMER_LAYERS) {} +LayerID::LayerID() + : id(0), transformer_layer_id(MAX_NUM_TRANSFORMER_LAYERS), model_id(0) {} -LayerID::LayerID(size_t _id, size_t _transformer_layer_id) - : id(_id), transformer_layer_id(_transformer_layer_id) { +LayerID::LayerID(size_t _id, size_t _transformer_layer_id, size_t _model_id) + : id(_id), transformer_layer_id(_transformer_layer_id), + model_id(_model_id) { assert(is_valid_id()); } bool LayerID::is_valid_id() const { - return (id >= LAYER_GUID_FIRST_VALID && id <= LAYER_GUID_LAST_VALID); + return (id >= LAYER_GUID_FIRST_VALID && id <= LAYER_GUID_LAST_VALID && + transformer_layer_id >= 0 && + transformer_layer_id < MAX_NUM_TRANSFORMER_LAYERS && model_id >= 0); } bool operator==(LayerID const &lhs, LayerID const &rhs) { // id should be sufficient to distinguish different layers if (lhs.id == rhs.id) { assert(lhs.transformer_layer_id == rhs.transformer_layer_id); + assert(lhs.model_id == rhs.model_id); } return lhs.id == rhs.id; } diff --git a/src/runtime/graph.cc b/src/runtime/graph.cc index 408de57c54..6d33dd9f27 100644 --- a/src/runtime/graph.cc +++ b/src/runtime/graph.cc @@ -2291,6 +2291,7 @@ GraphOptimalViewSerialized Embedding *embed = (Embedding *)op; sez.serialize(embed->layer_guid.id); sez.serialize(embed->layer_guid.transformer_layer_id); + sez.serialize(embed->layer_guid.model_id); sez.serialize(embed->num_entries); sez.serialize(embed->out_channels); sez.serialize(embed->aggr); @@ -2301,6 +2302,7 @@ GraphOptimalViewSerialized MultiHeadAttention *attn = (MultiHeadAttention *)op; sez.serialize(attn->layer_guid.id); sez.serialize(attn->layer_guid.transformer_layer_id); + sez.serialize(attn->layer_guid.model_id); sez.serialize(attn->oProjSize); sez.serialize(attn->num_heads); sez.serialize(attn->qProjSize); @@ -2315,6 +2317,7 @@ GraphOptimalViewSerialized IncMultiHeadSelfAttention *attn = (IncMultiHeadSelfAttention *)op; sez.serialize(attn->layer_guid.id); sez.serialize(attn->layer_guid.transformer_layer_id); + sez.serialize(attn->layer_guid.model_id); sez.serialize(attn->oProjSize); sez.serialize(attn->num_q_heads); sez.serialize(attn->qProjSize); @@ -2339,6 +2342,7 @@ GraphOptimalViewSerialized (SpecIncMultiHeadSelfAttention *)op; sez.serialize(attn->layer_guid.id); sez.serialize(attn->layer_guid.transformer_layer_id); + sez.serialize(attn->layer_guid.model_id); sez.serialize(attn->oProjSize); sez.serialize(attn->num_q_heads); sez.serialize(attn->qProjSize); @@ -2360,6 +2364,7 @@ GraphOptimalViewSerialized (TreeIncMultiHeadSelfAttention *)op; sez.serialize(attn->layer_guid.id); sez.serialize(attn->layer_guid.transformer_layer_id); + sez.serialize(attn->layer_guid.model_id); sez.serialize(attn->oProjSize); sez.serialize(attn->num_q_heads); 
sez.serialize(attn->qProjSize); @@ -2639,11 +2644,12 @@ void FFModel::deserialize_graph_optimal_view( assert(num_inputs == 1); AggrMode aggr; int num_entries, out_channels; - size_t id, transformer_layer_id; + size_t id, transformer_layer_id, deserialized_model_id; DataType data_type; dez.deserialize(id); dez.deserialize(transformer_layer_id); - LayerID layer_guid(id, transformer_layer_id); + dez.deserialize(deserialized_model_id); + LayerID layer_guid(id, transformer_layer_id, deserialized_model_id); dez.deserialize(num_entries); dez.deserialize(out_channels); dez.deserialize(aggr); @@ -2727,10 +2733,11 @@ void FFModel::deserialize_graph_optimal_view( int embed_dim, num_heads, k_dim, v_dim; float dropout; bool bias, add_bias_kv, add_zero_attn; - size_t id, transformer_layer_id; + size_t id, transformer_layer_id, deserialized_model_id; dez.deserialize(id); dez.deserialize(transformer_layer_id); - LayerID layer_guid(id, transformer_layer_id); + dez.deserialize(deserialized_model_id); + LayerID layer_guid(id, transformer_layer_id, deserialized_model_id); dez.deserialize(embed_dim); dez.deserialize(num_heads); dez.deserialize(k_dim); @@ -2762,10 +2769,11 @@ void FFModel::deserialize_graph_optimal_view( bool qkv_bias, final_bias, add_zero_attn, apply_rotary_embedding, scaling_query, qk_prod_scaling, offload, position_bias; DataType quantization_type; - size_t id, transformer_layer_id; + size_t id, transformer_layer_id, deserialized_model_id; dez.deserialize(id); dez.deserialize(transformer_layer_id); - LayerID layer_guid(id, transformer_layer_id); + dez.deserialize(deserialized_model_id); + LayerID layer_guid(id, transformer_layer_id, deserialized_model_id); dez.deserialize(embed_dim); dez.deserialize(num_q_heads); dez.deserialize(k_dim); @@ -2812,10 +2820,11 @@ void FFModel::deserialize_graph_optimal_view( float dropout, scaling_factor; bool qkv_bias, final_bias, add_zero_attn, apply_rotary_embedding, scaling_query, qk_prod_scaling, position_bias; - size_t id, transformer_layer_id; + size_t id, transformer_layer_id, deserialized_model_id; dez.deserialize(id); dez.deserialize(transformer_layer_id); - LayerID layer_guid(id, transformer_layer_id); + dez.deserialize(deserialized_model_id); + LayerID layer_guid(id, transformer_layer_id, deserialized_model_id); dez.deserialize(embed_dim); dez.deserialize(num_q_heads); dez.deserialize(k_dim); @@ -2859,10 +2868,11 @@ void FFModel::deserialize_graph_optimal_view( bool qkv_bias, final_bias, add_zero_attn, apply_rotary_embedding, scaling_query, qk_prod_scaling, offload, position_bias; DataType quantization_type; - size_t id, transformer_layer_id; + size_t id, transformer_layer_id, deserialized_model_id; dez.deserialize(id); dez.deserialize(transformer_layer_id); - LayerID layer_guid(id, transformer_layer_id); + dez.deserialize(deserialized_model_id); + LayerID layer_guid(id, transformer_layer_id, deserialized_model_id); dez.deserialize(embed_dim); dez.deserialize(num_q_heads); dez.deserialize(k_dim); diff --git a/src/runtime/hip_helper.cpp b/src/runtime/hip_helper.cpp index 1f27dc15e7..fb94135c8f 100644 --- a/src/runtime/hip_helper.cpp +++ b/src/runtime/hip_helper.cpp @@ -247,6 +247,112 @@ __host__ void checkCUDA(hipHostFree(host_ptr)); } +template <> +__host__ void + save_tensor(float const *ptr, size_t num_elements, char const *file_name) { + hipStream_t stream; + checkCUDA(get_legion_stream(&stream)); + float *host_ptr; + checkCUDA(hipHostMalloc(&host_ptr, + sizeof(float) * num_elements, + hipHostMallocPortable | hipHostMallocMapped)); + 
checkCUDA(hipMemcpyAsync(host_ptr, + ptr, + sizeof(float) * num_elements, + hipMemcpyDeviceToHost, + stream)); + checkCUDA(hipDeviceSynchronize()); + FILE *tensor_file; + tensor_file = fopen(file_name, "w"); + assert(tensor_file != NULL); + for (unsigned i = 0; i < num_elements; i++) { + fprintf(tensor_file, "%.9f, ", host_ptr[i]); + } + + fclose(tensor_file); + checkCUDA(hipHostFree(host_ptr)); +} + +template <> +__host__ void + save_tensor(half const *ptr, size_t num_elements, char const *file_name) { + hipStream_t stream; + checkCUDA(get_legion_stream(&stream)); + half *host_ptr; + checkCUDA(hipHostMalloc(&host_ptr, + sizeof(half) * num_elements, + hipHostMallocPortable | hipHostMallocMapped)); + checkCUDA(hipMemcpyAsync(host_ptr, + ptr, + sizeof(half) * num_elements, + hipMemcpyDeviceToHost, + stream)); + checkCUDA(hipDeviceSynchronize()); + FILE *tensor_file; + tensor_file = fopen(file_name, "w"); + assert(tensor_file != NULL); + for (unsigned i = 0; i < num_elements; i++) { + fprintf(tensor_file, "%.9f, ", (float)host_ptr[i]); + } + + fclose(tensor_file); + checkCUDA(hipHostFree(host_ptr)); +} + +template <> +__host__ void save_tensor(int32_t const *ptr, + size_t num_elements, + char const *file_name) { + hipStream_t stream; + checkCUDA(get_legion_stream(&stream)); + int32_t *host_ptr; + checkCUDA(hipHostMalloc(&host_ptr, + sizeof(int32_t) * num_elements, + hipHostMallocPortable | hipHostMallocMapped)); + checkCUDA(hipMemcpyAsync(host_ptr, + ptr, + sizeof(int32_t) * num_elements, + hipMemcpyDeviceToHost, + stream)); + checkCUDA(hipDeviceSynchronize()); + FILE *tensor_file; + tensor_file = fopen(file_name, "w"); + assert(tensor_file != NULL); + for (unsigned i = 0; i < num_elements; i++) { + fprintf(tensor_file, "%d, ", host_ptr[i]); + } + + fclose(tensor_file); + checkCUDA(hipHostFree(host_ptr)); +} + +template <> +__host__ void save_tensor(int64_t const *ptr, + size_t num_elements, + char const *file_name) { + hipStream_t stream; + checkCUDA(get_legion_stream(&stream)); + int64_t *host_ptr; + checkCUDA(hipHostMalloc(&host_ptr, + sizeof(int64_t) * num_elements, + hipHostMallocPortable | hipHostMallocMapped)); + checkCUDA(hipMemcpyAsync(host_ptr, + ptr, + sizeof(int64_t) * num_elements, + hipMemcpyDeviceToHost, + stream)); + checkCUDA(hipDeviceSynchronize()); + FILE *tensor_file; + tensor_file = fopen(file_name, "w"); + assert(tensor_file != NULL); + for (unsigned i = 0; i < num_elements; i++) { + fprintf(tensor_file, "%ld, ", host_ptr[i]); + } + + fclose(tensor_file); + checkCUDA(hipHostFree(host_ptr)); +} + template __host__ T *download_tensor(T const *ptr, size_t num_elements) { // device synchronize to make sure the data are ready @@ -493,6 +599,17 @@ template __host__ void template __host__ void print_tensor(half const *ptr, size_t rect, char const *prefix); +template __host__ void + save_tensor(float const *ptr, size_t rect, char const *file_name); +template __host__ void save_tensor(int32_t const *ptr, + size_t rect, + char const *file_name); +template __host__ void save_tensor(int64_t const *ptr, + size_t rect, + char const *file_name); +template __host__ void + save_tensor(half const *ptr, size_t rect, char const *file_name); + template __host__ float *download_tensor(float const *ptr, size_t num_elements); template __host__ half *download_tensor(half const *ptr, diff --git a/src/runtime/layer.cc b/src/runtime/layer.cc index d2473f4b2b..8f33f6db87 100644 --- a/src/runtime/layer.cc +++ b/src/runtime/layer.cc @@ -17,7 +17,8 @@ Layer::Layer(FFModel *model, const Tensor _input4) 
: op_type(_otype), data_type(_dtype), layer_guid(model->layer_global_guid++, - model->current_transformer_layer_id), + model->current_transformer_layer_id, + model->model_id), numInputs(_numInputs), numWeights(_numWeights), numOutputs(_numOutputs) { std::string pcname; if (_name == nullptr) { @@ -52,7 +53,8 @@ Layer::Layer(FFModel *model, Tensor const *_tensors) : op_type(_otype), data_type(_dtype), layer_guid(model->layer_global_guid++, - model->current_transformer_layer_id), + model->current_transformer_layer_id, + model->model_id), numInputs(_numInputs), numWeights(_numWeights), numOutputs(_numOutputs) { std::string pcname; if (_name == nullptr) { diff --git a/src/runtime/model.cc b/src/runtime/model.cc index 3fa201e7ab..92f0cff472 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -122,7 +122,8 @@ Op::Op(FFModel &model, const ParallelTensor _input4) : op_type(_otype), data_type(_dtype), op_guid(model.op_global_guid++), numInputs(_numInputs), numWeights(_numWeights), numOutputs(_numOutputs), - profiling(model.config.profiling) { + profiling(model.config.profiling), + inference_debugging(model.config.inference_debugging) { for (int i = 0; i < MAX_NUM_INPUTS; i++) { inputs[i] = NULL; } @@ -167,7 +168,8 @@ Op::Op(FFModel &model, ParallelTensor const *_inputs) : op_type(_otype), data_type(_dtype), op_guid(model.op_global_guid++), numInputs(_numInputs), numWeights(_numWeights), numOutputs(_numOutputs), - profiling(model.config.profiling) { + profiling(model.config.profiling), + inference_debugging(model.config.inference_debugging) { std::string pcname; if (_name == NULL) { pcname = get_operator_type_name(op_type); @@ -1462,7 +1464,8 @@ bool Op::get_weight_parameter(TNParameter tnp, return true; } -OpMeta::OpMeta(FFHandler _handle) : handle(_handle), profiling(false) { +OpMeta::OpMeta(FFHandler _handle) + : handle(_handle), profiling(false), inference_debugging(false) { for (int i = 0; i < MAX_NUM_INPUTS; i++) { trainableInputs[i] = true; } @@ -1475,6 +1478,7 @@ OpMeta::OpMeta(FFHandler _handle) : handle(_handle), profiling(false) { for (int i = 0; i < MAX_NUM_OUTPUTS; i++) { output_type[i] = DT_NONE; } + decoding_step = 0; } OpMeta::OpMeta(FFHandler _handle, Op const *op) : OpMeta(_handle) { @@ -1487,6 +1491,7 @@ OpMeta::OpMeta(FFHandler _handle, Op const *op) : OpMeta(_handle) { for (int i = 0; i < op->numOutputs; i++) { output_type[i] = op->outputs[i]->data_type; } + decoding_step = 0; } FFRuntime::FFRuntime(FFConfig &config) { @@ -1530,6 +1535,8 @@ FFRuntime::FFRuntime(FFConfig &config) { FFRuntime *ffruntime_singleton = nullptr; +int FFModel::model_counter = 0; + FFModel::FFModel(FFConfig &_config, bool cpu_offload) : op_global_guid(OP_GUID_FIRST_VALID), layer_global_guid(LAYER_GUID_FIRST_VALID), @@ -1570,6 +1577,7 @@ FFModel::FFModel(FFConfig &_config, bool cpu_offload) for (int idx = 0; idx < config.workersPerNode * config.numNodes; idx++) { handlers[idx] = ffruntime_singleton->handlers[idx]; } + model_id = model_counter++; } void FFModel::clear_graph_search_cache() { @@ -3969,6 +3977,7 @@ struct DefaultConfig { // const static int iterations = 1; const static int batchSize = 64; const static bool profiling = false; + const static bool inference_debugging = false; constexpr static float learningRate = 0.01f; constexpr static float weightDecay = 0.0001f; const static size_t workSpaceSize = (size_t)128 * 1024 * 1024; // 128 MB @@ -4003,6 +4012,7 @@ FFConfig::FFConfig() { // iterations = DefaultConfig::iterations; batchSize = DefaultConfig::batchSize; profiling = 
DefaultConfig::profiling; + inference_debugging = DefaultConfig::inference_debugging; learningRate = DefaultConfig::learningRate; weightDecay = DefaultConfig::weightDecay; workSpaceSize = DefaultConfig::workSpaceSize; @@ -4188,6 +4198,10 @@ void FFConfig::parse_args(char **argv, int argc) { profiling = true; continue; } + if (!strcmp(argv[i], "--inference-debugging")) { + inference_debugging = true; + continue; + } if (!strcmp(argv[i], "--allow-tensor-op-math-conversion")) { allow_tensor_op_math_conversion = true; continue; diff --git a/src/runtime/operator.cc b/src/runtime/operator.cc index 08b1af8ca5..0b3813f41c 100644 --- a/src/runtime/operator.cc +++ b/src/runtime/operator.cc @@ -3,6 +3,14 @@ #include "flexflow/simulator.h" #include +#include +#include +#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA) +#include "flexflow/utils/cuda_helper.h" +#else +#include "flexflow/utils/hip_helper.h" +#endif + namespace FlexFlow { size_t Op::get_untyped_params_hash() const { @@ -17,4 +25,105 @@ size_t Op::get_params_hash() const { get_operator_type_name(this->op_type)); } +/*static*/ +void Op::save_inference_tensors_to_file( + OpMeta *m, + int shard_id, + BatchConfig const *bc, + std::vector input_tensors, + std::vector weight_tensors, + std::vector output_tensors) { + + // Check if output directory exists, and create it if it does not + char const *folder_path = "./inference_tensors"; + struct stat st = {0}; + if (stat(folder_path, &st) == -1) { + // Directory does not exist, create it + mkdir(folder_path, 0700); + } + // output base filepath, shared by all tensors from the same operator + std::string base_filepath = + "./inference_tensors/model_" + std::to_string(m->layer_guid.model_id) + + "_decoding-step_" + std::to_string(m->decoding_step) + "_layer-num_" + + std::to_string(m->layer_guid.transformer_layer_id) + "_layer-name_" + + m->op_name + "_shard-id_" + std::to_string(shard_id); + // save batch config, if passed + if (bc != nullptr) { + bc->save_to_file(base_filepath + "_batch-config"); + } + // save all inputs + for (int i = 0; i < input_tensors.size(); i++) { + std::string filename = base_filepath + "_input_" + std::to_string(i); + if (input_tensors[i].data_type == DT_FLOAT) { + save_tensor(input_tensors[i].get_float_ptr(), + input_tensors[i].domain.get_volume(), + filename.c_str()); + } else if (input_tensors[i].data_type == DT_HALF) { + save_tensor(input_tensors[i].get_half_ptr(), + input_tensors[i].domain.get_volume(), + filename.c_str()); + } else if (input_tensors[i].data_type == DT_INT32) { + save_tensor(input_tensors[i].get_int32_ptr(), + input_tensors[i].domain.get_volume(), + filename.c_str()); + } else if (input_tensors[i].data_type == DT_INT64) { + save_tensor(input_tensors[i].get_int64_ptr(), + input_tensors[i].domain.get_volume(), + filename.c_str()); + } else { + assert(false && "Tensor data type not supported"); + } + } + // only dump the weights once + if (m->decoding_step == 0) { + for (int i = 0; i < weight_tensors.size(); i++) { + std::string filename = base_filepath + "_weight_" + std::to_string(i); + if (weight_tensors[i].data_type == DT_FLOAT) { + save_tensor(weight_tensors[i].get_float_ptr(), + weight_tensors[i].domain.get_volume(), + filename.c_str()); + } else if (weight_tensors[i].data_type == DT_HALF) { + save_tensor(weight_tensors[i].get_half_ptr(), + weight_tensors[i].domain.get_volume(), + filename.c_str()); + } else if (weight_tensors[i].data_type == DT_INT32) { + save_tensor(weight_tensors[i].get_int32_ptr(), + weight_tensors[i].domain.get_volume(), + 
filename.c_str()); + } else if (weight_tensors[i].data_type == DT_INT64) { + save_tensor(weight_tensors[i].get_int64_ptr(), + weight_tensors[i].domain.get_volume(), + filename.c_str()); + } else { + assert(false && "Tensor data type not supported"); + } + } + } + // save all outputs + for (int i = 0; i < output_tensors.size(); i++) { + std::string filename = base_filepath + "_output_" + std::to_string(i); + if (output_tensors[i].data_type == DT_FLOAT) { + save_tensor(output_tensors[i].get_float_ptr(), + output_tensors[i].domain.get_volume(), + filename.c_str()); + } else if (output_tensors[i].data_type == DT_HALF) { + save_tensor(output_tensors[i].get_half_ptr(), + output_tensors[i].domain.get_volume(), + filename.c_str()); + } else if (output_tensors[i].data_type == DT_INT32) { + save_tensor(output_tensors[i].get_int32_ptr(), + output_tensors[i].domain.get_volume(), + filename.c_str()); + } else if (output_tensors[i].data_type == DT_INT64) { + save_tensor(output_tensors[i].get_int64_ptr(), + output_tensors[i].domain.get_volume(), + filename.c_str()); + } else { + assert(false && "Tensor data type not supported"); + } + } + // increase count of decoding steps + m->decoding_step++; +} + }; // namespace FlexFlow \ No newline at end of file diff --git a/src/runtime/tree_verify_batch_config.cc b/src/runtime/tree_verify_batch_config.cc index 9efa06a2d5..cb68ecc5f1 100644 --- a/src/runtime/tree_verify_batch_config.cc +++ b/src/runtime/tree_verify_batch_config.cc @@ -14,6 +14,7 @@ */ #include "flexflow/batch_config.h" +#include "flexflow/request_manager.h" #include "legion.h" #include #include @@ -30,55 +31,72 @@ InferenceMode TreeVerifyBatchConfig::get_mode() const { return TREE_VERIFY_MODE; } -void TreeVerifyBatchConfig::print() const { - std::cout << "@@@@@@@@@@@@@@ TreeVerifyBatchConfig (mode " << get_mode() - << ") @@@@@@@@@@@@@@" << std::endl; - std::cout << "Max number of requests: " << max_requests_per_batch() - << std::endl; - std::cout << "Max number of tokens: " << max_tokens_per_batch() << std::endl; - std::cout << "Number of tokens: " << num_tokens << std::endl; - std::cout << "Number of requests: " << num_active_requests() << std::endl; - // std::cout << "Cached results: " << cached_results << std::endl; +std::ostream &operator<<(std::ostream &os, TreeVerifyBatchConfig const &bc) { + os << "@@@@@@@@@@@@@@ TreeVerifyBatchConfig (mode " << bc.get_mode() + << ") @@@@@@@@@@@@@@" << std::endl; + // Max values + os << "Max number of requests: " << bc.max_requests_per_batch() << std::endl; + os << "Max number of tokens: " << bc.max_tokens_per_batch() << std::endl; + os << "Max sequence length: " << bc.max_sequence_length() << std::endl; + // Current values + os << "Number of tokens: " << bc.num_active_tokens() << std::endl; + os << "Number of requests: " << bc.num_active_requests() << std::endl; + os << "Number of tokens to commit: " << bc.num_tokens_to_commit << std::endl; - std::cout << "Per-request info:\n"; - for (int i = 0; i < max_requests_per_batch(); i++) { - if (!request_completed[i]) { - std::cout << " Request " << i << ":\n"; - std::cout << " Token start offset: " - << requestsInfo[i].token_start_offset << std::endl; - std::cout << " Number of tokens in batch: " - << requestsInfo[i].num_tokens_in_batch << std::endl; - std::cout << " GUID: " << requestsInfo[i].request_guid << std::endl; - std::cout << " Max sequence length: " - << requestsInfo[i].max_sequence_length << std::endl; - std::cout << " Request completed: " << request_completed[i] - << std::endl; + os << "Per-request 
info:\n"; + for (int i = 0; i < bc.max_requests_per_batch(); i++) { + if (!bc.request_completed[i]) { + os << " Request " << i << ":\n"; + os << " Token start offset: " << bc.requestsInfo[i].token_start_offset + << std::endl; + os << " Number of tokens in batch: " + << bc.requestsInfo[i].num_tokens_in_batch << std::endl; + os << " GUID: " << bc.requestsInfo[i].request_guid << std::endl; + os << " Max sequence length: " + << bc.requestsInfo[i].max_sequence_length << std::endl; + os << " Request completed: " << bc.request_completed[i] << std::endl; + os << " Request running: " << bc.request_running[i] << std::endl; } } - std::cout << "Per-token info:\n"; - for (int i = 0; i < num_tokens; i++) { - std::cout << " Token " << i << ":\n"; - std::cout << " Absolute depth in request: " - << tokensInfo[i].abs_depth_in_request << std::endl; - std::cout << " Request index: " << tokensInfo[i].request_index - << std::endl; - std::cout << " Token id: " << tokensInfo[i].token_id << std::endl; + os << "Per-token info:\n"; + for (int i = 0; i < bc.num_tokens; i++) { + os << " Token " << i << ":\n"; + os << " Absolute depth in request: " + << bc.tokensInfo[i].abs_depth_in_request << std::endl; + os << " Request index: " << bc.tokensInfo[i].request_index << std::endl; + os << " Token id: " << bc.tokensInfo[i].token_id << std::endl; } - std::cout << "Tokens to commit info:\n"; - for (int i = 0; i < num_tokens_to_commit; i++) { - std::cout << " Token " << i << ":\n"; - std::cout << " token_index: " << committed_tokens[i].token_index - << std::endl; - std::cout << " request_index: " << committed_tokens[i].request_index - << std::endl; - std::cout << " token_depth: " << committed_tokens[i].token_depth - << std::endl; + os << "Tokens to commit info:\n"; + for (int i = 0; i < bc.num_tokens_to_commit; i++) { + os << " Token " << i << ":\n"; + os << " token_index: " << bc.committed_tokens[i].token_index + << std::endl; + os << " request_index: " << bc.committed_tokens[i].request_index + << std::endl; + os << " token_depth: " << bc.committed_tokens[i].token_depth + << std::endl; } - std::cout << "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@" - << std::endl; + os << "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@" << std::endl; + return os; +} + +void TreeVerifyBatchConfig::print() const { + std::cout << *this << std::endl; +} + +void TreeVerifyBatchConfig::save_to_file(std::string const &filename) const { + std::ofstream outputFile(filename); + if (outputFile.is_open()) { + outputFile << *this << std::endl; + outputFile.close(); + } else { + std::cerr << "Error: Unable to open the batch config output file: " + << filename << std::endl; + assert(false); + } } }; // namespace FlexFlow diff --git a/tests/inference/python_test_configs/generate_configs.py b/tests/inference/python_test_configs/generate_configs.py index e683faa469..609e15de49 100644 --- a/tests/inference/python_test_configs/generate_configs.py +++ b/tests/inference/python_test_configs/generate_configs.py @@ -18,6 +18,7 @@ "use_4bit_quantization": False, "use_8bit_quantization": False, "profiling": False, + "inference_debugging": False, "fusion": True, } llm_configs = {