Simplify BatchConfig #9

Open
zikun-li opened this issue Apr 12, 2024 · 11 comments
zikun-li commented Apr 12, 2024

As we plan to move some states from the BatchConfig to the RequestManager, some fields in BatchConfig are rendered redundant. The following are the data members of the current BatchConfig.

class BatchConfig {
public:
  static int const MAX_NUM_REQUESTS = 64;
  static int const MAX_NUM_TOKENS = 1024;
  static int const MAX_SPEC_TREE_TOKEN_NUM = 64;
  int num_tokens;
  int num_generation_tokens;
  struct PerRequestInfo {
    int first_token_depth_in_request;
    int first_token_offset_in_batch;
    int num_tokens_in_batch;
    int max_sequence_length;
    int batch_config_request_id;
    bool prompt_phase = false;
    RequestGuid request_guid;
  };
  struct PerTokenInfo {
    int abs_depth_in_request;
    int request_index;
    TokenId token_id;
  };
  BitMask causalMask[MAX_NUM_REQUESTS];
  PerRequestInfo requestsInfo[MAX_NUM_REQUESTS];
  PerTokenInfo tokensInfo[MAX_NUM_TOKENS];

  bool request_completed[MAX_NUM_REQUESTS];
  bool request_running[MAX_NUM_REQUESTS];
};

The following fields seem redundant:

  1. num_generation_tokens: in previous versions, the prefilling and decoding phases of the small model were mixed, and this field recorded the number of tokens generated by the small model in the decoding phase. Now that we separate the prefilling and decoding phases of the small model, all tokens are generated in the decoding phase, and the count is already available in num_tokens.
  2. PerRequestInfo.max_sequence_length: I think this should be a field of RequestManager.
  3. PerRequestInfo.batch_config_request_id: we can store a mapping from the index of a request in the batch to the guid of the request.
  4. request_completed: This is stored in RequestManager.
  5. request_running: This is stored in RequestManager.
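A minimal sketch of where this moved state could live, assuming a simple RequestManager shape (the names and layout here are illustrative, not the actual FlexFlow API; RequestGuid is assumed to be an integer id):

```cpp
#include <cstdint>
#include <unordered_map>
#include <vector>

using RequestGuid = uint64_t; // assumed integer guid type

// Hypothetical sketch: per-request state moved out of BatchConfig.
class RequestManager {
public:
  struct RequestState {
    int max_sequence_length = 0; // moved from PerRequestInfo
    bool completed = false;      // replaces BatchConfig::request_completed
    bool running = false;        // replaces BatchConfig::request_running
  };

  // Mapping from a request's index in the batch to its guid,
  // replacing PerRequestInfo::batch_config_request_id.
  std::vector<RequestGuid> guid_of_requests;

  // Per-request state keyed by guid.
  std::unordered_map<RequestGuid, RequestState> requests;
};
```
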

There are also some redundancies in the current TreeSearchBatchConfig; the following are its current data members:

class TreeSearchBatchConfig : public BatchConfig {
public:
  inline static int const MAX_SPECULATIVE_TREE_BRANCHES = 3;
  inline static int const MAX_TREE_DEPTH = 16;
  int speculative_request_num = 0;
  int current_depth = 0;
  int model_id;
  struct TreeSearchPerRequestInfo {
    int num_tokens_at_depth = 0;
  };
  TreeSearchPerRequestInfo tree_requests_info[MAX_NUM_REQUESTS];
};

As the base class BatchConfig already has a field in PerRequestInfo called num_tokens_in_batch, we don't need the struct TreeSearchPerRequestInfo, because its only field, num_tokens_at_depth, is equivalent to num_tokens_in_batch.

Please let me know if any of the fields listed above is not redundant. Otherwise, let's remove them. We can also discuss moving other fields into RequestManager if you have suggestions.

zikun-li commented Apr 12, 2024

The new version of the BatchConfig:

class BatchConfig {
public:
  static int const MAX_NUM_REQUESTS = 64;
  static int const MAX_NUM_TOKENS = 1024;
  static int const MAX_SPEC_TREE_TOKEN_NUM = 64;
  int num_tokens;
  int num_available_requests; // Added
  struct PerRequestInfo {
    int first_token_index_in_request;
    int first_token_offset_in_batch;
    int num_tokens_in_batch;
  };
  struct PerTokenInfo {
    int abs_index_in_request;
    int request_index;
    TokenId token_id;
  };
  BitMask causalMask[MAX_NUM_REQUESTS];
  PerRequestInfo requestsInfo[MAX_NUM_REQUESTS];
  PerTokenInfo tokensInfo[MAX_NUM_TOKENS];

  bool request_available[MAX_NUM_REQUESTS]; // Name changed
};

Field to remove:

  1. num_generation_tokens.
  2. PerRequestInfo.max_sequence_length.
  3. PerRequestInfo.batch_config_request_id.
  4. PerRequestInfo.prompt_phase.
  5. PerRequestInfo.request_guid.
  6. request_running.

Change name:

  1. request_completed -> request_available.

Add:

  1. A field indicating number of available requests in a BatchConfig.

The new version of TreeSearchBatchConfig:

class TreeSearchBatchConfig : public BatchConfig {
public:
  inline static int const MAX_SPECULATIVE_TREE_BRANCHES = 3;
  inline static int const MAX_TREE_DEPTH = 16;
  int current_depth = 0;
  int model_id;
};

Field to remove:

  1. tree_requests_info.
  2. speculative_request_num: because we decide to save this in BatchConfig.

chenzhuofu commented Apr 18, 2024

I'll take the kernel part of the development.

Got a question: I found some usages of batch_config_request_id in the .cu code, and it is now replaced by RequestManager::guid_of_requests. Should I assume that the request_manager will take on the duty of copying the guid_of_requests field into IncMultiHeadSelfAttentionMeta (just like token_infos and request_infos)?

@zikun-li

@chenzhuofu chenzhuofu self-assigned this Apr 18, 2024
@zikun-li

batch_config_request_id is a compact list of the indices of the available requests. In the new BatchConfig class we have an equivalent field, request_available, which is an indicator array recording whether each request slot is available. E.g., if we have a batch with 4 slots where slots 0, 2 and 3 have available requests, batch_config_request_id will be the compact list [0, 2, 3], and request_available will be [true, false, true, true].

Another difference is that in the previous BatchConfig, batch_config_request_id is scattered across the PerRequestInfo entries in requestsInfo: in the above example, requestsInfo[0].batch_config_request_id would be 0, requestsInfo[1].batch_config_request_id = 2, and requestsInfo[2].batch_config_request_id = 3. I believe we can adapt the implementation to use request_available instead of batch_config_request_id, since they contain virtually the same information. @aetiurf
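To make the claimed equivalence concrete, here is a small host-side helper (illustrative only, not from the codebase) that recovers the compact batch_config_request_id list from the request_available indicator array; for the example above it yields [0, 2, 3]:

```cpp
#include <vector>

// Recover the compact list of available request indices (the old
// batch_config_request_id) from the request_available indicator array.
std::vector<int> compact_request_ids(std::vector<bool> const &request_available) {
  std::vector<int> ids;
  for (int i = 0; i < static_cast<int>(request_available.size()); ++i) {
    if (request_available[i]) {
      ids.push_back(i); // slot i holds an available request
    }
  }
  return ids;
}
```
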

chenzhuofu commented Apr 18, 2024

Yes, we can indeed use request_available to calculate the $i$th available request index. However, my concern is here: https://github.com/flexflow/FlexFlow/blob/inference/src/ops/inc_multihead_self_attention.cu#L83

If we switch to request_available, we will obtain the index as:

int index = 0;
while (true) {
    // skip unavailable slots
    while (!request_available[index]) {
        ++index;
    }
    if (request_idx == 0) {
        break;
    }
    --request_idx;
    ++index;
}

And this would result in an $O(N)$ calculation in each kernel thread (which was previously $O(1)$).

xinhaoc commented Apr 18, 2024

I agree with the idea of keeping batch_config_request_id. Here we launch num_heads * num_requests (which equals num_generation_tokens in the decoding phase) threadblocks. If we switch to request_available and don't include the for-loop Zhuofu mentioned above, we need to launch num_heads * num_max_requests threadblocks, which results in a larger kernel.

@zikun-li

That's a valid concern, but I don't think we need to launch num_heads * num_max_requests threadblocks, because in the new version of BatchConfig we have a field BatchConfig.num_available_requests that indicates how many available requests are in the batch. Does this solve the issue? @xinhaoc

@aetiurf For this part, it seems we don't need to worry too much, since N is typically very small. We could also maintain a list like batch_config_request_id, but I think that might be a bit redundant. What do you think? @zwang86
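A host-side sketch of the proposed lookup (names are illustrative; the real code lives in the attention kernel): size the launch by num_available_requests, and map each block's request_idx to its batch slot with the O(N) scan discussed above:

```cpp
#include <vector>

// Hypothetical host-side mirror of the in-kernel lookup: map the i-th
// launched block (0 <= request_idx < num_available_requests) to its slot
// in the batch by scanning the request_available indicator array.
int nth_available_request(std::vector<bool> const &request_available,
                          int request_idx) {
  int index = 0;
  while (true) {
    // advance to the next available slot
    while (!request_available[index]) {
      ++index;
    }
    if (request_idx == 0) {
      return index; // found the request_idx-th available slot
    }
    --request_idx;
    ++index;
  }
}
```

With request_available = [true, false, true, true], the three launched blocks map to slots 0, 2, and 3, so only num_heads * num_available_requests blocks are needed rather than num_heads * MAX_NUM_REQUESTS.
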

@chenzhuofu

That's a valid concern, but I don't think we need to launch num_heads * num_max_requests threadblocks, because in the new version of BatchConfig we have a field BatchConfig.num_available_requests that indicates how many available requests are in the batch. Does this solve the issue? @xinhaoc

@aetiurf For this part, it seems we don't need to worry too much, since N is typically very small. We could also maintain a list like batch_config_request_id, but I think that might be a bit redundant. What do you think? @zwang86

Ah, I mixed up MAX_NUM_REQUESTS with 1024; for 64 there should be little performance degradation. Then I have no problem here.

@zikun-li

Actually it could be even smaller (e.g. 8 or 16).

@chenzhuofu

Yes, we can indeed use request_available to calculate the $i$th available request index. However, my concern is here: https://github.com/flexflow/FlexFlow/blob/inference/src/ops/inc_multihead_self_attention.cu#L83

If we switch to request_available, we will obtain the index as:

int index = 0;
while (true) {
    // skip unavailable slots
    while (!request_available[index]) {
        ++index;
    }
    if (request_idx == 0) {
        break;
    }
    --request_idx;
    ++index;
}

And this would result in an $O(N)$ calculation in each kernel thread (which was previously $O(1)$).

ANNOTATION:
Using the code snippet above will make the GPU threads fully divergent, so remember to add a trailing __syncthreads().

chenzhuofu commented Apr 22, 2024

Is speculative_request_num now BatchConfig::num_available_requests?
And is BeamSearchPerRequestInfo::sub_request_num now the constant MAX_SPECULATIVE_TREE_BRANCHES?

Please help me verify this :P

@zikun-li

In the prefilling phase, BatchConfig::num_available_requests should be 1, while in the decoding phase, BatchConfig::num_available_requests should be the same as speculative_request_num.

@lockshaw lockshaw transferred this issue from flexflow/flexflow-train Dec 16, 2024