Skip to content

Commit

Permalink
use ukkonen version (#92)
Browse files Browse the repository at this point in the history
  • Loading branch information
goliaro authored Jan 23, 2025
1 parent 563277f commit fd1e771
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 13 deletions.
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -292,11 +292,11 @@ if(NOT BUILD_LEGION_ONLY)
FetchContent_Declare(
suffix_decoding
GIT_REPOSITORY https://github.com/Snowflake-Labs/suffix-tree-decoding.git
GIT_TAG main # or a specific tag/commit hash
GIT_TAG ukkonen # or a specific tag/commit hash
)
FetchContent_MakeAvailable(suffix_decoding)
list(APPEND FLEXFLOW_INCLUDE_DIRS ${suffix_decoding_SOURCE_DIR}/src)
list(APPEND FLEXFLOW_SRC ${suffix_decoding_SOURCE_DIR}/src/suffix_decoding.cc)
list(APPEND FLEXFLOW_SRC ${suffix_decoding_SOURCE_DIR}/src/suffix_tree.cc)

set(FLEXFLOW_CPP_DRV_SRC
${FLEXFLOW_ROOT}/src/runtime/cpp_driver.cc)
Expand Down
8 changes: 4 additions & 4 deletions cmake/nccl.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ if(NCCL_LIBRARY AND NCCL_INCLUDE_DIR)
string(REGEX MATCH "([0-9]+)" NCCL_MAJOR ${NCCL_VERSION_DEFINES})
string(REGEX MATCH "([0-9]+)" NCCL_MINOR ${NCCL_VERSION_DEFINES2})
set(NCCL_VERSION "${NCCL_MAJOR}.${NCCL_MINOR}")
if(NCCL_VERSION VERSION_LESS 2.23)
set(NCCL_OLD TRUE)
else()
# if(NCCL_VERSION VERSION_LESS 2.23)
# set(NCCL_OLD TRUE)
# else()
set(NCCL_OLD FALSE)
endif()
# endif()
message(STATUS "Found NCCL version: ${NCCL_VERSION}")
else()
message(WARNING "NCCL header not found, unable to determine version")
Expand Down
6 changes: 3 additions & 3 deletions include/flexflow/request_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#include "flexflow/inference.h"
#include "flexflow/model.h"
#include "flexflow/utils/file_loader.h"
#include "suffix_decoding.h"
#include "suffix_tree.h"
#include <condition_variable>
#include <future>
#include <mutex>
Expand Down Expand Up @@ -150,7 +150,7 @@ struct Request {
std::vector<BatchConfig::TokenId> tokens;

// TokenTree speculative_token_tree;
SuffixTree *prompt_tree = nullptr;
SuffixTree<int> *prompt_tree = nullptr;
std::vector<int> suffix_decoding_best_token_ids;
std::vector<int> suffix_decoding_best_parents;
float suffix_decoding_best_score = 0.0f;
Expand Down Expand Up @@ -529,7 +529,7 @@ class RequestManager {
MatchingStrategy suffix_tree_matching_strategy;
float suffix_tree_max_spec_factor = -1.0f;
bool suffix_tree_online_tree_update = true;
SuffixTree *suffix_tree = nullptr;
SuffixTree<int> *suffix_tree = nullptr;

// Background server handler
Legion::Future background_server_handler;
Expand Down
2 changes: 1 addition & 1 deletion inference/suffix_decoding/suffix_decoding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
* limitations under the License.
*/

#include "suffix_decoding.h"
#include "suffix_tree.h"
#include "flexflow/inference.h"
#include "flexflow/request_manager.h"
#include "models/falcon.h"
Expand Down
6 changes: 3 additions & 3 deletions src/runtime/request_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ void RequestManager::init_suffix_tree(std::string const &trace_filepath,
std::vector<int> encoded = this->tokenizer_->Encode(entry.response);
training_dataset.push_back(encoded);
}
suffix_tree = new SuffixTree(training_dataset, suffix_tree_max_depth);
suffix_tree = new SuffixTree<int>(training_dataset);
}

RequestManager::RequestGuid
Expand Down Expand Up @@ -820,7 +820,7 @@ void RequestManager::insert_completed_request_into_suffix_tree(
request.tokens.end() - request.decode_length(), request.tokens.end());
assert(output_tokens.size() == request.decode_length());
if (output_tokens.size() > 0) {
suffix_tree->insert(output_tokens);
suffix_tree->add_entry(output_tokens);
}
long long int end_time = Realm::Clock::current_time_in_microseconds();
assert(profiling.tree_operation_step_times.size() > 0);
Expand Down Expand Up @@ -1107,7 +1107,7 @@ bool RequestManager::update_llm_prefill_results(InferenceResult const &result) {
assert(request->prompt_tree == nullptr && "Prompt tree was already initialized");
assert(this->suffix_tree_max_depth > 0 && "Invalid max depth for suffix tree");
assert(request->tokens.size() > 0 && "Attempting to create prompt tree for empty request");
request->prompt_tree = new SuffixTree({request->tokens}, this->suffix_tree_max_depth);
request->prompt_tree = new SuffixTree<int>({request->tokens});
}

if (decoding_mode == SPECULATIVE_DECODING) {
Expand Down

0 comments on commit fd1e771

Please sign in to comment.