diff --git a/include/flexflow/batch_config.h b/include/flexflow/batch_config.h
index 8aa69a3cad..d2fbd6219a 100644
--- a/include/flexflow/batch_config.h
+++ b/include/flexflow/batch_config.h
@@ -46,9 +46,8 @@ class BatchConfig {
   void print() const;
   virtual InferenceMode get_mode() const;
   static BatchConfig const *from_future(BatchConfigFuture const &future);
-  static int const MAX_NUM_REQUESTS = 4;
+  static int const MAX_NUM_REQUESTS = 7;
   static int const MAX_NUM_TOKENS = 64;
-  static int const MAX_PROMPT_LENGTH = 62;
   static int const MAX_SEQ_LENGTH = 256;
 
   // These are set by update
diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc
index 5489c9b06d..6f0a1f3851 100644
--- a/src/runtime/request_manager.cc
+++ b/src/runtime/request_manager.cc
@@ -1144,6 +1144,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify(
                     << std::endl;
         }
         new_bc.num_tokens_to_commit++;
+        request.llm_cache_size++;
       }
     }
   }
@@ -1255,6 +1256,19 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify(
             "Exceeding the space available in the TreeVerify batch");
         break;
       }
+
+      if (new_bc.num_tokens + request.llm_cache_size >= request.initial_len) {
+        // launch the request into running phase after loading all prompt
+        request.status = Request::RUNNING;
+        new_bc.request_running[i] = true;
+
+        std::cout << "new_bc.requestsInfo[i].num_tokens_in_batch: "
+                  << new_bc.requestsInfo[i].num_tokens_in_batch << std::endl;
+
+        dfs_tree_inputs[guid] =
+            std::vector<std::pair<BatchConfig::TokenId, int>>{std::make_pair(
+                request.tokens.back(), request.tokens.size() - 1)};
+      }
     } else {
       // launch the request into running phase after loading all prompt
       if (BatchConfig::MAX_NUM_TOKENS - new_bc.num_tokens > 0) {
         request.status = Request::RUNNING;