Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add TableSchema test #13

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
{
"files.associations": {
"cctype": "cpp",
"cmath": "cpp",
"cstdarg": "cpp",
"cstddef": "cpp",
"cstdio": "cpp",
"cstdlib": "cpp",
"cstring": "cpp",
"ctime": "cpp",
"cwchar": "cpp",
"array": "cpp",
"atomic": "cpp",
"*.tcc": "cpp",
"bitset": "cpp",
"compare": "cpp",
"concepts": "cpp",
"cstdint": "cpp",
"map": "cpp",
"exception": "cpp",
"algorithm": "cpp",
"functional": "cpp",
"iterator": "cpp",
"memory_resource": "cpp",
"string": "cpp",
"system_error": "cpp",
"tuple": "cpp",
"type_traits": "cpp",
"utility": "cpp",
"initializer_list": "cpp",
"iosfwd": "cpp",
"iostream": "cpp",
"istream": "cpp",
"limits": "cpp",
"mutex": "cpp",
"new": "cpp",
"ostream": "cpp",
"stdexcept": "cpp",
"streambuf": "cpp",
"typeinfo": "cpp",
"strstream": "cpp",
"bit": "cpp",
"clocale": "cpp",
"complex": "cpp",
"condition_variable": "cpp",
"cwctype": "cpp",
"deque": "cpp",
"set": "cpp",
"unordered_map": "cpp",
"vector": "cpp",
"list": "cpp",
"memory": "cpp",
"numbers": "cpp",
"optional": "cpp",
"ratio": "cpp",
"semaphore": "cpp",
"sstream": "cpp",
"stop_token": "cpp",
"string_view": "cpp",
"thread": "cpp",
"typeindex": "cpp",
"variant": "cpp",
"__bit_reference": "cpp",
"__config": "cpp",
"__debug": "cpp",
"__errc": "cpp",
"__locale": "cpp",
"__node_handle": "cpp",
"__threading_support": "cpp",
"__tree": "cpp",
"__verbose_abort": "cpp",
"ios": "cpp",
"locale": "cpp"
}
}
44 changes: 30 additions & 14 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ add_compile_options(-fPIC)

find_package(OpenMP REQUIRED)

add_subdirectory(include/DisComm)

add_subdirectory(thirdparty)

add_library(alaya
Expand All @@ -61,6 +63,8 @@ add_library(alaya
${CMAKE_CURRENT_SOURCE_DIR}/python/bindings.cpp
include/alaya/index/graph/nsglib/utils.h
include/alaya/index/quantizer/normal_quantizer.h
${CMAKE_CURRENT_SOURCE_DIR}/src/Database.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/Table.cpp
)
target_link_libraries(alaya
PUBLIC
Expand Down Expand Up @@ -95,29 +99,35 @@ else()
target_link_libraries(alaya PUBLIC ${BLAS_LIBRARIES})
endif()

add_subdirectory(python)
add_subdirectory(tests)

# add_executable(test_RQ test_RQ.cpp)
# target_link_libraries(test_RQ alaya)

# add_executable(test_main test_main.cpp)
# target_link_libraries(test_main alaya)

enable_testing()
add_executable(
distance_test
tests/distance_test.cpp
tests/graph_test.cpp
# tests/graph_test.cpp
)
target_link_libraries(
distance_test
alaya
GTest::gtest_main
)

add_executable(
pq_test
tests/pq_test.cpp
)
target_link_libraries(
pq_test
alaya
GTest::gtest_main
)
# add_executable(
# pq_test
# tests/pq_test.cpp
# )
# target_link_libraries(
# pq_test
# alaya
# GTest::gtest_main
# )

add_executable(
ivf_test
Expand All @@ -129,6 +139,10 @@ target_link_libraries(
GTest::gtest_main
)

add_executable(tableschema_test tests/tableschema_test.cpp)
target_link_libraries(tableschema_test alaya GTest::gtest_main)
target_link_libraries(tableschema_test SQLite::SQLite3)

add_executable(
hello_test
tests/hello_test.cpp
Expand All @@ -139,8 +153,10 @@ target_link_libraries(
GTest::gtest_main
)


include(GoogleTest)
gtest_discover_tests(distance_test)
gtest_discover_tests(pq_test)
gtest_discover_tests(ivf_test)
# gtest_discover_tests(distance_test)
gtest_discover_tests(hello_test)
gtest_discover_tests(tableschema_test)
# gtest_discover_tests(pq_test)
gtest_discover_tests(ivf_test)
37 changes: 37 additions & 0 deletions include/DisComm/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Builds the DisComm executable from main.cpp and links it against the alaya
# target, the MPI libraries and the BLAS libraries.
#
# add_compile_definitions() (used below) was introduced in CMake 3.12, so the
# previously declared minimum of 3.0.0 could never have configured this file.
cmake_minimum_required(VERSION 3.12)
project(DisComm VERSION 0.1.0 LANGUAGES C CXX)

include(CTest)
enable_testing()

# ------------------------------------------------------------------------------
# MPI (https://cmake.org/cmake/help/v3.0/module/FindMPI.html)
# ------------------------------------------------------------------------------
# Machine-specific overrides kept for reference; uncomment and adapt if the MPI
# installation is not discovered automatically.
# set(MPIEXEC_EXECUTABLE /usr/local/mpich-3.4.3/bin)
# set(MPI_HOME /usr/local/mpich-3.4.3)
# set(MPIEXEC_EXECUTABLE /usr/local/mpich-3.4.3/bin/mpiexec)
# include(/home/dongjiang/jd/rdma-rpc/cmake/FindMPI.cmake)
find_package(MPI REQUIRED)
if(MPI_FOUND)
    message(STATUS "[Find MPI]: YES, and version is ${MPI_VERSION}")
    message(STATUS "\t- The MPI Include - origin: ${MPI_INCLUDE_PATH}")

    # NOTE(review): MPI_INCLUDE_DIR is not set anywhere in this file, so this
    # line likely prints an empty value - confirm or drop.
    message(STATUS "\t- The MPI Include: ${MPI_INCLUDE_DIR}")

    # Without this include path, <mpi.h> may not be found.
    include_directories(${MPI_INCLUDE_PATH})
    message(STATUS "\t- The MPI Lib - origin: ${MPI_CXX_LIBRARIES}")

    # NOTE(review): MPI_LIB is not set anywhere in this file either - confirm
    # or drop.
    message(STATUS "\t- The MPI Lib: ${MPI_LIB}")
else()
    # Defensive only: find_package(MPI REQUIRED) already aborts when MPI is missing.
    message(FATAL_ERROR "[Find MPI]: NO")
endif()


# NOTE(review): BLAS is looked up as optional, yet USE_BLAS is defined
# unconditionally. Left unchanged because BLAS may reach this target through
# the alaya target - confirm whether the definition should depend on BLAS_FOUND.
find_package(BLAS)
add_compile_definitions(USE_BLAS)
add_executable(DisComm main.cpp)
target_link_libraries(DisComm PUBLIC alaya ${MPI_LIBRARIES} ${BLAS_LIBRARIES})

set(CPACK_PROJECT_NAME ${PROJECT_NAME})
set(CPACK_PROJECT_VERSION ${PROJECT_VERSION})
include(CPack)
154 changes: 154 additions & 0 deletions include/DisComm/main.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
#include "main.hpp"

#include <algorithm>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <string>
#include <vector>

#include "mpi_env.hpp"
#include "message_queue.hpp"
#include "../alaya/utils/metric_type.h"
#include "../alaya/utils/io_utils.h"
#include <alaya/index/bucket/ivf.h>
#include <alaya/searcher/ivf_searcher.h>


/// Entry point of the DisComm driver.
///
/// The process whose serverId() is 0 loads the query vectors, sends them
/// through a MessageQueue using ids 1, 2 and 3 (presumably the destination
/// server ids - confirm against MessageQueue), receives an id matrix and a
/// distance matrix for each of those ids, merges and sorts them, prints the
/// ids of the last L-vector row and reports the elapsed wall-clock time.
/// Every other process currently does nothing between environment setup and
/// teardown; its intended worker logic is kept below as commented-out code.
///
/// @param argc argument count, forwarded to Env::initEnv
/// @param argv argument vector, forwarded to Env::initEnv
/// @return 0 on normal completion
int main(int argc, char** argv)
{
    Env::initEnv(argc, argv);
    std::cout << "Hello, from DisComm!\n";
    if (serverId() == 0) {
        // Number of result rows per peer ("Lvec.size()" in the original notes).
        constexpr uint32_t kLvecCount = 30;

        const std::string query_path = "/dataset/netflix/netflix_query.fvecs";
        unsigned q_num = 0, q_dim = 0, K = 10;
        // NOTE(review): the buffer returned by AlignLoadVecs is never released;
        // its allocator is not visible here, so it is left untouched.
        float *query = alaya::AlignLoadVecs<float>(query_path.c_str(), q_num, q_dim);
        // q_num = 1;
        uint32_t query_buffer_capacity = (q_num * q_dim * sizeof(float) + 2 * sizeof(size_t)) * 2;
        uint32_t result_buffer_capacity = static_cast<uint32_t>(
            (K * kLvecCount * q_num * sizeof(float) + 2 * sizeof(size_t)) * 1.2);

        const auto s = std::chrono::high_resolution_clock::now();
        // NOTE(review): msgq0 is never deleted. MessageQueue's destructor is not
        // visible here, so ownership is left as-is - confirm it is safe to
        // destroy the queue before Env::endEnv() and then switch to RAII.
        Message::MessageQueue<char> *msgq0 =
            new Message::MessageQueue<char>(query_buffer_capacity, result_buffer_capacity);
        msgq0->sendMessage<float>(1, query, q_num, q_dim, QUERY);
        msgq0->sendMessage<float>(2, query, q_num, q_dim, QUERY);
        msgq0->sendMessage<float>(3, query, q_num, q_dim, QUERY);

        // Results are collected in the order 3, 2, 1, as in the original code.
        uint32_t **result_id2 = nullptr;
        float **distance2 = nullptr;
        size_t result_num2 = 0, result_dim2 = 0;
        msgq0->recvMessage<uint32_t>(3, result_id2, result_num2, result_dim2, IDX);
        msgq0->recvMessage<float>(3, distance2, result_num2, result_dim2, DIST);

        uint32_t **result_id1 = nullptr;
        float **distance1 = nullptr;
        size_t result_num1 = 0, result_dim1 = 0;
        msgq0->recvMessage<uint32_t>(2, result_id1, result_num1, result_dim1, IDX);
        msgq0->recvMessage<float>(2, distance1, result_num1, result_dim1, DIST);

        uint32_t **result_id0 = nullptr;
        float **distance0 = nullptr;
        size_t result_num0 = 0, result_dim0 = 0;
        msgq0->recvMessage<uint32_t>(1, result_id0, result_num0, result_dim0, IDX);
        msgq0->recvMessage<float>(1, distance0, result_num0, result_dim0, DIST);

        std::vector<std::vector<uint32_t>> res_id(result_num0);
        std::vector<std::vector<float>> res_dist(result_num0);

        // NOTE(review): all three merges use the shape received for id 1
        // (result_num0 / result_dim0); this presumably relies on every peer
        // returning the same shape - confirm.
        merge_to_vector1<uint32_t, float, uint32_t, float>(res_id, res_dist, result_id0, distance0, result_num0,
                                                           result_dim0);
        merge_to_vector2<uint32_t, float, uint32_t, float>(res_id, res_dist, result_id1, distance1, result_num0,
                                                           result_dim0);
        merge_to_vector3<uint32_t, float, uint32_t, float>(res_id, res_dist, result_id2, distance2, result_num0,
                                                           result_dim0);

        sort_vector<uint32_t, float>(res_id, res_dist, K, q_num);

        //= begin Lvec.size() loop----------------------------------------------
        for (uint32_t test_id = 0; test_id < kLvecCount; test_id++)
        {
            // Only the last row is printed. Both bounds are checked because
            // res_id is sized from the received data, not from kLvecCount or K.
            if (test_id == kLvecCount - 1 && test_id < res_id.size())
            {
                const std::vector<uint32_t> &ids = res_id[test_id];
                const std::size_t shown = std::min<std::size_t>(K, ids.size());
                for (std::size_t ids_index = 0; ids_index < shown; ++ids_index)
                    printf("after sort result_id in Lvec %u, is: %u\n", test_id, ids[ids_index]);
            }
            std::cout << std::endl;
            // Disabled recall computation. Re-enabling it requires declaring and
            // loading gt_ids, gt_dists, gt_dim, best_recall and calc_recall_flag.
            // double recall = 0;
            // if (calc_recall_flag)
            // {
            // //* Compute recall; the function returns the recall value and is evaluated once per L.
            // //* Arguments: q_num (number of queries), gt_ids, gt_dists, gt_dim,
            // //* query_result_id[test_id] is the start address of a 1-D array and recall_at is K,
            // //* so the argument passing is fine.
            // recall = diskann::calculate_recall((uint32_t)q_num, gt_ids, gt_dists, (uint32_t)gt_dim,
            // res_id[test_id].data(), K, K);
            // // Keep the best recall across the L results.
            // best_recall = std::max(recall, best_recall);
            // }
            // if (calc_recall_flag)
            // {
            // diskann::cout << std::setw(16) << recall << std::endl;
            // }
            // else
            // diskann::cout << std::endl;
        } // end of Lvec.size()
        const auto dis_time = std::chrono::high_resolution_clock::now();
        const std::chrono::duration<double> total_dis_time = dis_time - s;
        std::cout << "----------------------total single time is: " << total_dis_time.count() << std::endl;
    } // end of serverId()=0

    else
    {
        // Worker-side logic, currently disabled.
        // // try
        // // {
        // unsigned K = 10;
        // uint32_t query_buffer_capacity = (1000 * 960 * sizeof(float) + 2 * sizeof(size_t)) * 2;
        // uint32_t result_buffer_capacity = (K * 30 * 1000 * sizeof(float) + 2 * sizeof(size_t)) * 1.2;

        // Message::MessageQueue<char> *msgq1 =
        // new Message::MessageQueue<char>(result_buffer_capacity, query_buffer_capacity);
        // size_t query_num, query_dim = 0;
        // float* recv_query = nullptr;
        // float* distance = new float[K];
        // int64_t* result_id = new int64_t[K];

        // msgq1->recvMessage(0, recv_query, query_num, query_dim, QUERY);

        // std::string netflix_path = "/dataset/netflix";
        // unsigned d_num, d_dim, q_num, q_dim;
        // float* data = alaya::LoadVecs<float>(fmt::format("{}/netflix_base.fvecs", netflix_path).c_str(),
        // d_num, d_dim);
        // float* query = alaya::AlignLoadVecs<float>(
        // fmt::format("{}/netflix_query.fvecs", netflix_path).c_str(), q_num, q_dim);
        // fmt::println("d_num: {}, d_dim: {}, q_num: {}, q_dim: {}", d_num, d_dim, q_num, q_dim);
        // unsigned bucket_num = 100;
        // alaya::IVF<float> ivf(d_dim, alaya::MetricType::L2, bucket_num);

        // ivf.BuildIndex(d_num, data);

        // alaya::IvfSearcher<alaya::MetricType::L2, float> searcher(&ivf);

        // searcher.SetNprobe(10);

        // searcher.Search(query, q_dim, K, distance, result_id);

        // for (auto i = 0; i < K; ++i) {
        // fmt::println("distance: {}, result_id: {}", distance[i], result_id[i]);
        // }

        // msgq1->sendMessage<int64_t>(0, &result_id, 30, K * query_num, IDX);
        // msgq1->sendMessage<float>(0, &distance, 30, K * query_num, DIST);
        // delete[] distance;
        // delete[] result_id;

    } // end of serverId()=1

    Env::endEnv();
    return 0;
}
Loading