Modify BeamSearchBatchConfig to Support General Tree Structure. #13

zikun-li commented Mar 15, 2024

This issue is opened to discuss the plan to modify BeamSearchBatchConfig to support general tree structures. Since this data structure is used by many ongoing projects, we want to make sure that changing it won't block others' development.

The current BeamSearchBatchConfig class is as follows:

class BeamSearchBatchConfig : public BatchConfig {
public:
  BeamSearchBatchConfig();
  BeamSearchBatchConfig(int model_id);
  BeamSearchBatchConfig(size_t beam_width, size_t target_iterations);
  BeamSearchBatchConfig(BeamSearchBatchConfig const &other, int model_id);
  InferenceMode get_mode() const;

  ~BeamSearchBatchConfig();

  friend std::ostream &operator<<(std::ostream &os,
                                  BeamSearchBatchConfig const &bc);
  void print() const;
  void save_to_file(std::string const &filename) const;
  bool done() const;
  int max_beam_depth_all_requests() const; // Need to remove
  int current_depth_all_requests() const;
  int get_speculative_request_num() const;

  size_t beam_width;
  size_t target_iterations;

  // how many requests are in the speculative phase
  int speculative_request_num = 0;
  inline static int const MAX_BEAM_WIDTH = 3; // Need to remove
  inline static int const MAX_BEAM_DEPTH = 8; // Need to remove

  // maximum tree branches for a request
  inline static int const MAX_SPECULATIVE_TREE_BRANCHES = 3; // Need to remove

  int model_id;

  struct BeamSearchPerRequestInfo {
    int beam_size;
    int current_depth = -1;
    int max_depth = MAX_BEAM_DEPTH;

    BatchConfig::TokenId
        tokens[BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES];
    float probs[BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES];
    int parent_id[BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES];
    int sub_request_num;
  };

  struct BeamSearchPerTokenInfo {
    int sub_request_index;
  };

  BeamSearchPerRequestInfo beamRequestsInfo[MAX_NUM_REQUESTS];
  BeamSearchPerTokenInfo
      beamTokenInfo[MAX_NUM_TOKENS +
                    MAX_SPEC_TREE_TOKEN_NUM * MAX_NUM_REQUESTS];

  int sub_requests[MAX_NUM_REQUESTS];

private:
  size_t current_iteration;
};

Several redundant member variables and methods are listed below. Please take a look at them, since the redundancy of some of them may need further verification:

  1. beam_width: we need to remove this since we want to support general tree structures.
  2. target_iterations: not used.
  3. current_iteration: not used.
  4. sub_requests: not used.
  5. done: not used.
  6. MAX_BEAM_WIDTH: remove to support general tree structures.
  7. MAX_BEAM_DEPTH: remove to support general tree structures.
  8. MAX_SPECULATIVE_TREE_BRANCHES: remove to support general tree structures.
  9. BeamSearchPerRequestInfo::beam_size, BeamSearchPerRequestInfo::max_depth and BeamSearchPerRequestInfo::sub_request_num: remove to support general tree structures.

The proposed new data structure to substitute BeamSearchBatchConfig is as follows:

class TreeSearchBatchConfig : public BatchConfig {
public:
  TreeSearchBatchConfig();
  TreeSearchBatchConfig(int model_id);
  TreeSearchBatchConfig(TreeSearchBatchConfig const &other, int model_id);
  InferenceMode get_mode() const;

  ~TreeSearchBatchConfig();

  friend std::ostream &operator<<(std::ostream &os,
                                  TreeSearchBatchConfig const &bc);
  void print() const;
  void save_to_file(std::string const &filename) const;
  int current_depth_all_requests() const;
  int get_speculative_request_num() const;

  // how many requests are in the speculative phase
  int speculative_request_num = 0;
  int model_id;

  struct TreeSearchPerRequestInfo {
    int current_depth = -1;

    std::vector<BatchConfig::TokenId> tokens_vec;
    std::vector<float> probs_vec;
    std::vector<int> parent_id_vec;
  };

  std::vector<TreeSearchPerRequestInfo> tree_requests_info;
};

All the arrays are replaced with vectors; please let us know if that could cause a problem.

@zikun-li zikun-li changed the title Modify BeamTreeBatchConfig to Support General Tree Structure. Modify BeamSearchBatchConfig to Support General Tree Structure. Mar 16, 2024
zwang86 commented Mar 17, 2024

I remember that we cannot use vector here, since this data structure will be sent to the GPU. That's the reason we use preallocated arrays for everything. Correct me if I am wrong.

@zikun-li

I see. Thanks for pointing this out! If you still remember, could you please point me to the code that passes this data structure to the GPU?

@jiazhihao

The data structure is automatically memcpy'd from the CPU's address space to the GPU's address space. Using any STL container is therefore problematic: a std::vector is essentially a pointer to a heap-allocated array, and the memcpy copies only the pointer, not the underlying array. As a result, you cannot access the vector's elements from a GPU thread.

@jiazhihao

You should replace the vectors with arrays in the data structure.

@zikun-li

Okay, that makes sense. It seems that if we preallocate the arrays, we have to set the size of the arrays in TreeSearchPerRequestInfo equal to our budget. Also, do we have access to the function that passes this data structure to the GPU?

@zikun-li

zikun-li commented Mar 17, 2024

Does the following look good? By the way, which one is the budget, MAX_NUM_TOKENS or MAX_SPEC_TREE_TOKEN_NUM?

class TreeSearchBatchConfig : public BatchConfig {
public:
  TreeSearchBatchConfig();
  TreeSearchBatchConfig(int model_id);
  TreeSearchBatchConfig(TreeSearchBatchConfig const &other, int model_id);
  InferenceMode get_mode() const;

  ~TreeSearchBatchConfig();

  friend std::ostream &operator<<(std::ostream &os,
                                  TreeSearchBatchConfig const &bc);
  void print() const;
  void save_to_file(std::string const &filename) const;
  int current_depth_all_requests() const;
  int get_speculative_request_num() const;

  // how many requests is in speculative phase
  int speculative_request_num = 0;
  int model_id;

  struct TreeSearchPerRequestInfo {
    int current_depth = -1;

    BatchConfig::TokenId tokens_arr[BatchConfig::MAX_NUM_TOKENS];
    float probs_arr[BatchConfig::MAX_NUM_TOKENS];
    int parent_id_arr[BatchConfig::MAX_NUM_TOKENS];
  };

  TreeSearchPerRequestInfo tree_requests_info[MAX_NUM_REQUESTS];
};

zwang86 commented Mar 19, 2024

MAX_NUM_TOKENS is the max number of tokens in a batch, so I think the budget should be MAX_SPEC_TREE_TOKEN_NUM. @xinhaoc Can you confirm that? Thank you.

@zikun-li

Thanks for the comment. Is MAX_SPEC_TREE_TOKEN_NUM the budget for a single request or for the whole batch?

@zwang86

zwang86 commented Mar 22, 2024

It should be the budget for a single request. The budget for the entire batch would be MAX_SPEC_TREE_TOKEN_NUM * MAX_NUM_REQUESTS.

@lockshaw lockshaw transferred this issue from flexflow/flexflow-train Dec 16, 2024