diff --git a/cpp/HierarchicalNSW.hpp b/cpp/HierarchicalNSW.hpp index b0acdea..74663a4 100644 --- a/cpp/HierarchicalNSW.hpp +++ b/cpp/HierarchicalNSW.hpp @@ -1,370 +1,424 @@ #pragma once -#include -#include -#include #include "RnHostClass.hpp" #include "RnJsiContext.hpp" #include "helper.h" #include "hnswlib/hnswlib.h" +#include +#include namespace RNHnswlib { - void normalize_vector(float* data) { - float norm = 0.0f; - for (int i = 0; i < sizeof(data); i++) - norm += data[i] * data[i]; - norm = 1.0f / (sqrtf(norm) + 1e-30f); - for (int i = 0; i < sizeof(data); i++) - data[i] *= norm; +void normalize_vector(float *data) { + float norm = 0.0f; + for (int i = 0; i < sizeof(data); i++) + norm += data[i] * data[i]; + norm = 1.0f / (sqrtf(norm) + 1e-30f); + for (int i = 0; i < sizeof(data); i++) + data[i] *= norm; +} + +class HierarchicalNSW : public RnJSI::HostClass { +public: + // constructor(spaceName: SpaceName, numDimensions: number); + JSI_HOST_OBJECT_CONSTRUCTOR(HierarchicalNSW) { + JSI_ASSERT(runtime, argc == 2, "Expected 2 arguments"); + JSI_ASSERT(runtime, args[0].isString(), "Expected string"); + JSI_ASSERT(runtime, args[1].isNumber(), "Expected number"); + + std::string spaceName = args[0].getString(runtime).utf8(runtime); + int numDimensions = args[1].getNumber(); + + if (spaceName == "l2") { + space_ = std::make_unique( + static_cast(numDimensions)); + } else if (spaceName == "ip") { + space_ = std::make_unique( + static_cast(numDimensions)); + } else if (spaceName == "cosine") { + space_ = std::make_unique( + static_cast(numDimensions)); + normalize_ = true; + } else { + throw facebook::jsi::JSError(runtime, "Invalid space name"); + } } - class HierarchicalNSW: public RnJSI::HostClass { - public: - // constructor(spaceName: SpaceName, numDimensions: number); - JSI_HOST_OBJECT_CONSTRUCTOR(HierarchicalNSW) { - JSI_ASSERT(runtime, argc == 2, "Expected 2 arguments"); - JSI_ASSERT(runtime, args[0].isString(), "Expected string"); + // initIndex(maxElementsOrOpts: number | object, m?: number, efConstruction?: + // number, randomSeed?: number, allowReplaceDeleted?: boolean): void; + JSI_HOST_FUNCTION(initIndex) { + JSI_ASSERT(runtime, argc >= 1, "Expected at least 1 argument"); + JSI_ASSERT(runtime, args[0].isNumber() || args[0].isObject(), + "Expected number or object"); + if (argc >= 2) JSI_ASSERT(runtime, args[1].isNumber(), "Expected number"); - - std::string spaceName = args[0].getString(runtime).utf8(runtime); - int numDimensions = args[1].getNumber(); - - if (spaceName == "l2") { - space_ = std::make_unique(static_cast(numDimensions)); - } else if (spaceName == "ip") { - space_ = std::make_unique(static_cast(numDimensions)); - } else if (spaceName == "cosine") { - space_ = std::make_unique(static_cast(numDimensions)); - normalize_ = true; - } else { - throw facebook::jsi::JSError(runtime, "Invalid space name"); - } + if (argc >= 3) + JSI_ASSERT(runtime, args[2].isNumber(), "Expected number"); + if (argc >= 4) + JSI_ASSERT(runtime, args[3].isNumber(), "Expected number"); + if (argc >= 5) + JSI_ASSERT(runtime, args[4].isBool(), "Expected boolean"); + + const bool use_opts = args[0].isObject(); + + if (use_opts) + JSI_ASSERT(runtime, + args[0].getObject(runtime).hasProperty(runtime, "maxElements"), + "`maxElements` is required"); + + const uint32_t max_elements = use_opts + ? args[0] + .getObject(runtime) + .getProperty(runtime, "maxElements") + .getNumber() + : args[0].getNumber(); + const uint32_t m = + use_opts + ? JSI_GET_PROPERTY_OR_DEFAULT(args[0].getObject(runtime), runtime, + "m", 16) + .getNumber() + : JSI_GET_ARGUMENT_OR_DEFAULT(args, runtime, 1, 16).getNumber(); + + const uint32_t ef_construction = + use_opts + ? JSI_GET_PROPERTY_OR_DEFAULT(args[0].getObject(runtime), runtime, + "efConstruction", 200) + .getNumber() + : JSI_GET_ARGUMENT_OR_DEFAULT(args, runtime, 2, 200).getNumber(); + + const uint32_t random_seed = + use_opts + ? JSI_GET_PROPERTY_OR_DEFAULT(args[0].getObject(runtime), runtime, + "randomSeed", 100) + .getNumber() + : JSI_GET_ARGUMENT_OR_DEFAULT(args, runtime, 3, 100).getNumber(); + + const bool allow_replace_deleted = + use_opts + ? JSI_GET_PROPERTY_OR_DEFAULT(args[0].getObject(runtime), runtime, + "allowReplaceDeleted", false) + .getBool() + : JSI_GET_ARGUMENT_OR_DEFAULT(args, runtime, 4, false).getBool(); + + if (index_ != nullptr) + index_.reset(); + + try { + index_ = std::make_unique>( + space_.get(), max_elements, m, ef_construction, random_seed, + allow_replace_deleted); + } catch (const std::exception &e) { + throw facebook::jsi::JSError(runtime, e.what()); } - // initIndex(maxElementsOrOpts: number | object, m?: number, efConstruction?: number, randomSeed?: number, allowReplaceDeleted?: boolean): void; - JSI_HOST_FUNCTION(initIndex) { - JSI_ASSERT(runtime, argc >= 1, "Expected at least 1 argument"); - JSI_ASSERT(runtime, args[0].isNumber() || args[0].isObject(), "Expected number or object"); - if (argc >= 2) JSI_ASSERT(runtime, args[1].isNumber(), "Expected number"); - if (argc >= 3) JSI_ASSERT(runtime, args[2].isNumber(), "Expected number"); - if (argc >= 4) JSI_ASSERT(runtime, args[3].isNumber(), "Expected number"); - if (argc >= 5) JSI_ASSERT(runtime, args[4].isBool(), "Expected boolean"); - - const bool use_opts = args[0].isObject(); - - if (use_opts) - JSI_ASSERT(runtime, args[0].getObject(runtime).hasProperty(runtime, "maxElements"), "`maxElements` is required"); + return facebook::jsi::Value::undefined(); + } - const uint32_t max_elements = use_opts ? args[0].getObject(runtime).getProperty(runtime, "maxElements").getNumber() : args[0].getNumber(); - const uint32_t m = use_opts - ? JSI_GET_PROPERTY_OR_DEFAULT(args[0].getObject(runtime), runtime, "m", 16).getNumber() - : JSI_GET_ARGUMENT_OR_DEFAULT(args, runtime, 1, 16).getNumber(); + // readIndex(filename: string, allowReplaceDeleted?: boolean): + // Promise; + JSI_HOST_FUNCTION(readIndex) { + JSI_ASSERT(runtime, argc >= 1, "Expected at least 1 argument"); + JSI_ASSERT(runtime, args[0].isString(), "Expected string"); + if (argc >= 2) + JSI_ASSERT(runtime, args[1].isBool(), "Expected boolean"); + + std::string filename = args[0].getString(runtime).utf8(runtime); + bool allow_replace_deleted = + JSI_GET_ARGUMENT_OR_DEFAULT(args, runtime, 1, false).getBool(); + + return JSI_CREATE_PROMISE( + runtime, + { + if (index_ != nullptr) + index_.reset(); + + try { + index_ = std::make_unique>( + space_.get(), filename, false, 0, allow_replace_deleted); + resolve.call(runtime, true); + } catch (const std::exception &e) { + reject.call(runtime, e.what()); + } + }, + this, filename, allow_replace_deleted); + } - const uint32_t ef_construction = use_opts - ? JSI_GET_PROPERTY_OR_DEFAULT(args[0].getObject(runtime), runtime, "efConstruction", 200).getNumber() - : JSI_GET_ARGUMENT_OR_DEFAULT(args, runtime, 2, 200).getNumber(); + // readIndexSync(filename: string, allowReplaceDeleted?: boolean): void; + JSI_HOST_FUNCTION(readIndexSync) { + JSI_ASSERT(runtime, argc >= 1, "Expected at least 1 argument"); + JSI_ASSERT(runtime, args[0].isString(), "Expected string"); + if (argc >= 2) + JSI_ASSERT(runtime, args[1].isBool(), "Expected boolean"); + + const std::string filename = args[0].getString(runtime).utf8(runtime); + const bool allow_replace_deleted = + JSI_GET_ARGUMENT_OR_DEFAULT(args, runtime, 1, false).getBool(); + + if (index_ != nullptr) + index_.reset(); + + try { + index_ = std::make_unique>( + space_.get(), filename, false, 0, allow_replace_deleted); + } catch (const std::exception &e) { + throw facebook::jsi::JSError(runtime, e.what()); + } - const uint32_t random_seed = use_opts - ? JSI_GET_PROPERTY_OR_DEFAULT(args[0].getObject(runtime), runtime, "randomSeed", 100).getNumber() - : JSI_GET_ARGUMENT_OR_DEFAULT(args, runtime, 3, 100).getNumber(); + return facebook::jsi::Value::undefined(); + } - const bool allow_replace_deleted = use_opts - ? JSI_GET_PROPERTY_OR_DEFAULT(args[0].getObject(runtime), runtime, "allowReplaceDeleted", false).getBool() - : JSI_GET_ARGUMENT_OR_DEFAULT(args, runtime, 4, false).getBool(); + // writeIndex(filename: string): Promise; + JSI_HOST_FUNCTION(writeIndex) { + JSI_ASSERT(runtime, argc == 1 && args[0].isString(), "Expected string"); + JSI_ASSERT(runtime, index_ == nullptr, "Index is not initialized"); + + const std::string filename = args[0].getString(runtime).utf8(runtime); + + return JSI_CREATE_PROMISE( + runtime, + { + try { + index_->saveIndex(filename); + resolve.call(runtime); + } catch (const std::exception &e) { + reject.call(runtime, e.what()); + } + }, + this, filename); + } - if (index_ != nullptr) index_.reset(); + // writeIndexSync(filename: string): void; + JSI_HOST_FUNCTION(writeIndexSync) { + JSI_ASSERT(runtime, argc == 1 && args[0].isString(), "Expected string"); + JSI_ASSERT(runtime, index_ == nullptr, "Index is not initialized"); - try { - index_ = std::make_unique>(space_.get(), max_elements, m, ef_construction, random_seed, allow_replace_deleted); - } catch (const std::exception &e) { - throw facebook::jsi::JSError(runtime, e.what()); - } + const std::string filename = args[0].getString(runtime).utf8(runtime); - return facebook::jsi::Value::undefined(); + try { + index_->saveIndex(filename); + } catch (const std::exception &e) { + throw facebook::jsi::JSError(runtime, e.what()); } - // readIndex(filename: string, allowReplaceDeleted?: boolean): Promise; - JSI_HOST_FUNCTION(readIndex) { - JSI_ASSERT(runtime, argc >= 1, "Expected at least 1 argument"); - JSI_ASSERT(runtime, args[0].isString(), "Expected string"); - if (argc >= 2) JSI_ASSERT(runtime, args[1].isBool(), "Expected boolean"); - - std::string filename = args[0].getString(runtime).utf8(runtime); - bool allow_replace_deleted = JSI_GET_ARGUMENT_OR_DEFAULT(args, runtime, 1, false).getBool(); - - return JSI_CREATE_PROMISE(runtime, { - if (index_ != nullptr) index_.reset(); - - try { - index_ = std::make_unique>(space_.get(), filename, false, 0, allow_replace_deleted); - resolve.call(runtime, true); - } catch (const std::exception &e) { - reject.call(runtime, e.what()); - } - }, this, filename, allow_replace_deleted); - } + return facebook::jsi::Value::undefined(); + } - // readIndexSync(filename: string, allowReplaceDeleted?: boolean): void; - JSI_HOST_FUNCTION(readIndexSync) { - JSI_ASSERT(runtime, argc >= 1, "Expected at least 1 argument"); - JSI_ASSERT(runtime, args[0].isString(), "Expected string"); - if (argc >= 2) JSI_ASSERT(runtime, args[1].isBool(), "Expected boolean"); + // resizeIndex(newMaxElements: number): void; + JSI_HOST_FUNCTION(resizeIndex) { + JSI_ASSERT(runtime, argc == 1 && args[0].isNumber(), "Expected number"); + JSI_ASSERT(runtime, index_ == nullptr, "Index is not initialized"); - const std::string filename = args[0].getString(runtime).utf8(runtime); - const bool allow_replace_deleted = JSI_GET_ARGUMENT_OR_DEFAULT(args, runtime, 1, false).getBool(); + const uint32_t new_max_elements = args[0].getNumber(); - if (index_ != nullptr) index_.reset(); + try { + index_->resizeIndex(new_max_elements); + } catch (const std::exception &e) { + throw facebook::jsi::JSError(runtime, e.what()); + } - try { - index_ = std::make_unique>(space_.get(), filename, false, 0, allow_replace_deleted); - } catch (const std::exception &e) { - throw facebook::jsi::JSError(runtime, e.what()); - } + return facebook::jsi::Value::undefined(); + } - return facebook::jsi::Value::undefined(); + // addPoint(point: number[], label: number, replaceDeleted?: boolean): void; + JSI_HOST_FUNCTION(addPoint) { + JSI_ASSERT(runtime, argc >= 2, "Expected at least 2 arguments"); + JSI_ASSERT(runtime, args[0].isArray(), "Expected array"); + JSI_ASSERT(runtime, args[1].isNumber(), "Expected number"); + if (argc >= 3) + JSI_ASSERT(runtime, args[2].isBool(), "Expected boolean"); + JSI_ASSERT(runtime, index_ == nullptr, "Index is not initialized"); + + std::vector vector; + auto point = args[0].asArray(runtime); + for (size_t i = 0; i < point.length(runtime); i++) { + vector.push_back(point.getValueAtIndex(runtime, i).getNumber()); } + const uint32_t label = args[1].getNumber(); + const bool replace_deleted = + JSI_GET_ARGUMENT_OR_DEFAULT(args, runtime, 2, false).getBool(); - // writeIndex(filename: string): Promise; - JSI_HOST_FUNCTION(writeIndex) { - JSI_ASSERT(runtime, argc == 1 && args[0].isString(), "Expected string"); - JSI_ASSERT(runtime, index_ == nullptr, "Index is not initialized"); - - const std::string filename = args[0].getString(runtime).utf8(runtime); - - return JSI_CREATE_PROMISE(runtime, { - try { - index_->saveIndex(filename); - resolve.call(runtime); - } catch (const std::exception &e) { - reject.call(runtime, e.what()); - } - }, this, filename); + if (normalize_) { + normalize_vector(vector.data()); } - // writeIndexSync(filename: string): void; - JSI_HOST_FUNCTION(writeIndexSync) { - JSI_ASSERT(runtime, argc == 1 && args[0].isString(), "Expected string"); - JSI_ASSERT(runtime, index_ == nullptr, "Index is not initialized"); - - const std::string filename = args[0].getString(runtime).utf8(runtime); - - try { - index_->saveIndex(filename); - } catch (const std::exception &e) { - throw facebook::jsi::JSError(runtime, e.what()); - } - - return facebook::jsi::Value::undefined(); + try { + index_->addPoint(vector.data(), label, replace_deleted); + } catch (const std::exception &e) { + throw facebook::jsi::JSError(runtime, e.what()); } - // resizeIndex(newMaxElements: number): void; - JSI_HOST_FUNCTION(resizeIndex) { - JSI_ASSERT(runtime, argc == 1 && args[0].isNumber(), "Expected number"); - JSI_ASSERT(runtime, index_ == nullptr, "Index is not initialized"); - - const uint32_t new_max_elements = args[0].getNumber(); + return facebook::jsi::Value::undefined(); + } - try { - index_->resizeIndex(new_max_elements); - } catch (const std::exception &e) { - throw facebook::jsi::JSError(runtime, e.what()); - } + // markDelete(label: number): void; + JSI_HOST_FUNCTION(markDelete) { + JSI_ASSERT(runtime, argc == 1 && args[0].isNumber(), "Expected number"); + JSI_ASSERT(runtime, index_ == nullptr, "Index is not initialized"); - return facebook::jsi::Value::undefined(); - } + const uint32_t label = args[0].getNumber(); - // addPoint(point: number[], label: number, replaceDeleted?: boolean): void; - JSI_HOST_FUNCTION(addPoint) { - JSI_ASSERT(runtime, argc >= 2, "Expected at least 2 arguments"); - JSI_ASSERT(runtime, args[0].isArray(), "Expected array"); - JSI_ASSERT(runtime, args[1].isNumber(), "Expected number"); - if (argc >= 3) JSI_ASSERT(runtime, args[2].isBool(), "Expected boolean"); - JSI_ASSERT(runtime, index_ == nullptr, "Index is not initialized"); - - std::vector vector; - auto point = args[0].asArray(runtime); - for (size_t i = 0; i < point.length(runtime); i++) { - vector.push_back(point.getValueAtIndex(runtime, i).getNumber()); - } - const uint32_t label = args[1].getNumber(); - const bool replace_deleted = JSI_GET_ARGUMENT_OR_DEFAULT(args, runtime, 2, false).getBool(); - - if (normalize_) { - normalize_vector(vector.data()); - } - - try { - index_->addPoint(vector.data(), label, replace_deleted); - } catch (const std::exception &e) { - throw facebook::jsi::JSError(runtime, e.what()); - } - - return facebook::jsi::Value::undefined(); + try { + index_->markDelete(label); + } catch (const std::exception &e) { + throw facebook::jsi::JSError(runtime, e.what()); } - // markDelete(label: number): void; - JSI_HOST_FUNCTION(markDelete) { - JSI_ASSERT(runtime, argc == 1 && args[0].isNumber(), "Expected number"); - JSI_ASSERT(runtime, index_ == nullptr, "Index is not initialized"); + return facebook::jsi::Value::undefined(); + } - const uint32_t label = args[0].getNumber(); + // unmarkDelete(label: number): void; + JSI_HOST_FUNCTION(unmarkDelete) { + JSI_ASSERT(runtime, argc == 1 && args[0].isNumber(), "Expected number"); + JSI_ASSERT(runtime, index_ == nullptr, "Index is not initialized"); - try { - index_->markDelete(label); - } catch (const std::exception &e) { - throw facebook::jsi::JSError(runtime, e.what()); - } + const uint32_t label = args[0].getNumber(); - return facebook::jsi::Value::undefined(); + try { + index_->unmarkDelete(label); + } catch (const std::exception &e) { + throw facebook::jsi::JSError(runtime, e.what()); } - // unmarkDelete(label: number): void; - JSI_HOST_FUNCTION(unmarkDelete) { - JSI_ASSERT(runtime, argc == 1 && args[0].isNumber(), "Expected number"); - JSI_ASSERT(runtime, index_ == nullptr, "Index is not initialized"); + return facebook::jsi::Value::undefined(); + } - const uint32_t label = args[0].getNumber(); + // searchKnn(queryPoint: number[], numNeighbors: number, filter?: + // FilterFunction): {distances: number[], neighbors: number[]}; + JSI_HOST_FUNCTION(searchKnn) { + JSI_ASSERT(runtime, argc >= 2, "Expected at least 2 arguments"); + JSI_ASSERT(runtime, args[0].isArray(), "Expected array"); + JSI_ASSERT(runtime, args[1].isNumber(), "Expected number"); + if (argc >= 3) + JSI_ASSERT(runtime, args[2].isObject(), "Expected object"); + + std::vector query_vector; + auto query_point = args[0].asArray(runtime); + for (size_t i = 0; i < query_point.length(runtime); i++) { + query_vector.push_back( + query_point.getValueAtIndex(runtime, i).getNumber()); + } + const uint32_t num_neighbors = args[1].getNumber(); + + hnswlib::BaseFilterFunctor *filter = nullptr; + if (argc >= 3) { + auto filter_function = args[2].getObject(runtime).asFunction(runtime); + filter = new hnswlib::BaseFilterFunctor([filter_function](int label) { + auto future = RnJsi::Context::invoke(filter_function, {label}); + return future.get().getBool(); + }); + } - try { - index_->unmarkDelete(label); - } catch (const std::exception &e) { - throw facebook::jsi::JSError(runtime, e.what()); - } + std::priority_queue> result = + index_->searchKnn(query_vector.data(), num_neighbors, filter); - return facebook::jsi::Value::undefined(); + std::vector distances; + std::vector neighbors; + while (!result.empty()) { + distances.push_back(result.top().first); + neighbors.push_back(result.top().second); + result.pop(); } - // searchKnn(queryPoint: number[], numNeighbors: number, filter?: FilterFunction): {distances: number[], neighbors: number[]}; - JSI_HOST_FUNCTION(searchKnn) { - JSI_ASSERT(runtime, argc >= 2, "Expected at least 2 arguments"); - JSI_ASSERT(runtime, args[0].isArray(), "Expected array"); - JSI_ASSERT(runtime, args[1].isNumber(), "Expected number"); - if (argc >= 3) JSI_ASSERT(runtime, args[2].isObject(), "Expected object"); - - std::vector query_vector; - auto query_point = args[0].asArray(runtime); - for (size_t i = 0; i < query_point.length(runtime); i++) { - query_vector.push_back(query_point.getValueAtIndex(runtime, i).getNumber()); - } - const uint32_t num_neighbors = args[1].getNumber(); - - hnswlib::BaseFilterFunctor *filter = nullptr; - if (argc >= 3) { - auto filter_function = args[2].getObject(runtime).asFunction(runtime); - filter = new hnswlib::BaseFilterFunctor([filter_function](int label) { - auto future = RnJsi::Context::invoke(filter_function, {label}); - return future.get().getBool(); - }); - } - - std::priority_queue> result = index_->searchKnn(query_vector.data(), num_neighbors, filter); - - std::vector distances; - std::vector neighbors; - while (!result.empty()) { - distances.push_back(result.top().first); - neighbors.push_back(result.top().second); - result.pop(); - } - - if (filter != nullptr) delete filter; - - auto resultObject = jsi::Object(runtime); - resultObject.setProperty(runtime, "distances", jsi::Array(runtime, distances)); - resultObject.setProperty(runtime, "neighbors", jsi::Array(runtime, neighbors)); - - return resultObject; - } + if (filter != nullptr) + delete filter; - // getIdsList(): number[]; - JSI_HOST_FUNCTION(getIdsList) { - JSI_ASSERT(runtime, argc == 0, "Expected 0 arguments"); - JSI_ASSERT(runtime, index_ == nullptr, "Index is not initialized"); + auto resultObject = jsi::Object(runtime); + resultObject.setProperty(runtime, "distances", + jsi::Array(runtime, distances)); + resultObject.setProperty(runtime, "neighbors", + jsi::Array(runtime, neighbors)); - std::vector ids_list = index_->getIdsList(); + return resultObject; + } - return jsi::Array(runtime, ids_list); - } + // getIdsList(): number[]; + JSI_HOST_FUNCTION(getIdsList) { + JSI_ASSERT(runtime, argc == 0, "Expected 0 arguments"); + JSI_ASSERT(runtime, index_ == nullptr, "Index is not initialized"); - // getPoint(label: number): number[]; - JSI_HOST_FUNCTION(getPoint) { - JSI_ASSERT(runtime, argc == 1 && args[0].isNumber(), "Expected number"); - JSI_ASSERT(runtime, index_ == nullptr, "Index is not initialized"); + std::vector ids_list = index_->getIdsList(); - const uint32_t label = args[0].getNumber(); + return jsi::Array(runtime, ids_list); + } - std::vector point = index_->getDataByLabel(static_cast(label)); + // getPoint(label: number): number[]; + JSI_HOST_FUNCTION(getPoint) { + JSI_ASSERT(runtime, argc == 1 && args[0].isNumber(), "Expected number"); + JSI_ASSERT(runtime, index_ == nullptr, "Index is not initialized"); - return jsi::Array(runtime, point); - } + const uint32_t label = args[0].getNumber(); - // getMaxElements(): number; - JSI_HOST_FUNCTION(getMaxElements) { - JSI_ASSERT(runtime, argc == 0, "Expected 0 arguments"); - JSI_ASSERT(runtime, index_ == nullptr, "Index is not initialized"); + std::vector point = + index_->getDataByLabel(static_cast(label)); - return index_->maxelements_; - } + return jsi::Array(runtime, point); + } - // getCurrentCount(): number; - JSI_HOST_FUNCTION(getCurrentCount) { - JSI_ASSERT(runtime, argc == 0, "Expected 0 arguments"); - JSI_ASSERT(runtime, index_ == nullptr, "Index is not initialized"); + // getMaxElements(): number; + JSI_HOST_FUNCTION(getMaxElements) { + JSI_ASSERT(runtime, argc == 0, "Expected 0 arguments"); + JSI_ASSERT(runtime, index_ == nullptr, "Index is not initialized"); - return index_->cur_element_count; - } + return index_->maxelements_; + } - // getNumDimensions(): number; - JSI_HOST_FUNCTION(getNumDimensions) { - JSI_ASSERT(runtime, argc == 0, "Expected 0 arguments"); - JSI_ASSERT(runtime, index_ == nullptr, "Index is not initialized"); + // getCurrentCount(): number; + JSI_HOST_FUNCTION(getCurrentCount) { + JSI_ASSERT(runtime, argc == 0, "Expected 0 arguments"); + JSI_ASSERT(runtime, index_ == nullptr, "Index is not initialized"); - return index_->data_size_ / sizeof(float); - } + return index_->cur_element_count; + } - // getEf(): number; - JSI_HOST_FUNCTION(getEf) { - JSI_ASSERT(runtime, argc == 0, "Expected 0 arguments"); - JSI_ASSERT(runtime, index_ == nullptr, "Index is not initialized"); + // getNumDimensions(): number; + JSI_HOST_FUNCTION(getNumDimensions) { + JSI_ASSERT(runtime, argc == 0, "Expected 0 arguments"); + JSI_ASSERT(runtime, index_ == nullptr, "Index is not initialized"); - return index_->ef_; - } + return index_->data_size_ / sizeof(float); + } - // setEf(ef: number): void; - JSI_HOST_FUNCTION(setEf) { - JSI_ASSERT(runtime, argc == 1 && args[0].isNumber(), "Expected number"); - JSI_ASSERT(runtime, index_ == nullptr, "Index is not initialized"); + // getEf(): number; + JSI_HOST_FUNCTION(getEf) { + JSI_ASSERT(runtime, argc == 0, "Expected 0 arguments"); + JSI_ASSERT(runtime, index_ == nullptr, "Index is not initialized"); - const uint32_t ef = args[0].getNumber(); + return index_->ef_; + } - index_->ef_ = ef; + // setEf(ef: number): void; + JSI_HOST_FUNCTION(setEf) { + JSI_ASSERT(runtime, argc == 1 && args[0].isNumber(), "Expected number"); + JSI_ASSERT(runtime, index_ == nullptr, "Index is not initialized"); - return facebook::jsi::Value::undefined(); - } + const uint32_t ef = args[0].getNumber(); - JSI_HOST_CLASS_DEFINE_METHODS({ - {"initIndex", JSI_BIND_METHOD(initIndex)}, - {"readIndex", JSI_BIND_METHOD(readIndex)}, - {"readIndexSync", JSI_BIND_METHOD(readIndexSync)}, - {"writeIndex", JSI_BIND_METHOD(writeIndex)}, - {"writeIndexSync", JSI_BIND_METHOD(writeIndexSync)}, - {"resizeIndex", JSI_BIND_METHOD(resizeIndex)}, - {"addPoint", JSI_BIND_METHOD(addPoint)}, - {"markDelete", JSI_BIND_METHOD(markDelete)}, - {"unmarkDelete", JSI_BIND_METHOD(unmarkDelete)}, - {"searchKnn", JSI_BIND_METHOD(searchKnn)}, - {"getIdsList", JSI_BIND_METHOD(getIdsList)}, - {"getPoint", JSI_BIND_METHOD(getPoint)}, - {"getMaxElements", JSI_BIND_METHOD(getMaxElements)}, - {"getCurrentCount", JSI_BIND_METHOD(getCurrentCount)}, - {"getNumDimensions", JSI_BIND_METHOD(getNumDimensions)}, - {"getEf", JSI_BIND_METHOD(getEf)}, - {"setEf", JSI_BIND_METHOD(setEf)} - }) - - private: - std::unique_ptr> index_ = nullptr; - std::unique_ptr> space_ = nullptr; - bool normalize_ = false; - }; -} + index_->ef_ = ef; + + return facebook::jsi::Value::undefined(); + } + + JSI_HOST_CLASS_DEFINE_METHODS( + {{"initIndex", JSI_BIND_METHOD(initIndex)}, + {"readIndex", JSI_BIND_METHOD(readIndex)}, + {"readIndexSync", JSI_BIND_METHOD(readIndexSync)}, + {"writeIndex", JSI_BIND_METHOD(writeIndex)}, + {"writeIndexSync", JSI_BIND_METHOD(writeIndexSync)}, + {"resizeIndex", JSI_BIND_METHOD(resizeIndex)}, + {"addPoint", JSI_BIND_METHOD(addPoint)}, + {"markDelete", JSI_BIND_METHOD(markDelete)}, + {"unmarkDelete", JSI_BIND_METHOD(unmarkDelete)}, + {"searchKnn", JSI_BIND_METHOD(searchKnn)}, + {"getIdsList", JSI_BIND_METHOD(getIdsList)}, + {"getPoint", JSI_BIND_METHOD(getPoint)}, + {"getMaxElements", JSI_BIND_METHOD(getMaxElements)}, + {"getCurrentCount", JSI_BIND_METHOD(getCurrentCount)}, + {"getNumDimensions", JSI_BIND_METHOD(getNumDimensions)}, + {"getEf", JSI_BIND_METHOD(getEf)}, + {"setEf", JSI_BIND_METHOD(setEf)}}) + +private: + std::unique_ptr> index_ = nullptr; + std::unique_ptr> space_ = nullptr; + bool normalize_ = false; +}; +} // namespace RNHnswlib diff --git a/cpp/RnJsiContext.hpp b/cpp/RnJsiContext.hpp index b6434ab..eb31bec 100644 --- a/cpp/RnJsiContext.hpp +++ b/cpp/RnJsiContext.hpp @@ -1,42 +1,50 @@ #pragma once -#include -#include +#include #include +#include #include -#include +#include namespace RnJsi { class Context { public: - Context(facebook::jsi::Runtime *runtime, std::shared_ptr jsCallInvoker): runtime_(runtime), jsCallInvoker_(jsCallInvoker) {} + Context(facebook::jsi::Runtime *runtime, + std::shared_ptr jsCallInvoker) + : runtime_(runtime), jsCallInvoker_(jsCallInvoker) {} - static inline void init(facebook::jsi::Runtime *runtime, std::shared_ptr jsCallInvoker) { + static inline void + init(facebook::jsi::Runtime *runtime, + std::shared_ptr jsCallInvoker) { instance.reset(new Context(runtime, jsCallInvoker)); } - static std::future invoke(facebook::jsi::Function fn, std::initializer_list args) { + static std::future + invoke(facebook::jsi::Function fn, + std::initializer_list args) { std::promise promise; std::future future = promise.get_future(); if (instance && instance->jsCallInvoker) { - instance->jsCallInvoker->invokeAsync([&fn, &args, &promise, runtime = instance->runtime] { - try { - promise.set_value(fn.call(*runtime, args)); - } catch (...) { - promise.set_exception(std::current_exception()); - } - }); + instance->jsCallInvoker->invokeAsync( + [&fn, &args, &promise, runtime = instance->runtime] { + try { + promise.set_value(fn.call(*runtime, args)); + } catch (...) { + promise.set_exception(std::current_exception()); + } + }); } else { - promise.set_exception(std::make_exception_ptr(std::runtime_error("Context not initialized"))); + promise.set_exception(std::make_exception_ptr( + std::runtime_error("Context not initialized"))); } return future; } private: - static std::shared_ptr instance { nullptr }; + static std::shared_ptr instance{nullptr}; std::shared_ptr runtime_; std::shared_ptr jsCallInvoker_; }; -} +} // namespace RnJsi diff --git a/cpp/RnJsiHostClass.hpp b/cpp/RnJsiHostClass.hpp index 2f2915f..496aac4 100644 --- a/cpp/RnJsiHostClass.hpp +++ b/cpp/RnJsiHostClass.hpp @@ -1,37 +1,39 @@ #pragma once -#include #include "RnJsiContext.hpp" +#include -#define JSI_HOST_OBJECT_CONSTRUCTOR(NAME) \ - NAME(jsi::Runtime &runtime, const jsi::Object &thisValue, const jsi::Value *args, size_t argc): RnJSI::HostClass() +#define JSI_HOST_OBJECT_CONSTRUCTOR(NAME) \ + NAME(jsi::Runtime &runtime, const jsi::Object &thisValue, \ + const jsi::Value *args, size_t argc) \ + : RnJSI::HostClass() -#define JSI_HOST_CLASS_DEFINE_METHODS(METHODS) \ +#define JSI_HOST_CLASS_DEFINE_METHODS(METHODS) \ std::unordered_map getMethods() const override { \ - return METHODS; \ + return METHODS; \ } namespace RnJsi { - template - class HostClass: public facebook::jsi::HostObject { - public: - HostClass() = default; +template class HostClass : public facebook::jsi::HostObject { +public: + HostClass() = default; - static JSI_HOST_FUNCTION(constructor) { - auto instance = std::make_shared(runtime, thisValue, args, argc); - auto prototype = jsi::Object(runtime); - for (auto const& [key, value]: instance->getMethods()) { - prototype.setProperty(runtime, key, value); - } - instance->setProperty(runtime, "prototype", prototype); - return instance; + static JSI_HOST_FUNCTION(constructor) { + auto instance = std::make_shared(runtime, thisValue, args, argc); + auto prototype = jsi::Object(runtime); + for (auto const &[key, value] : instance->getMethods()) { + prototype.setProperty(runtime, key, value); } + instance->setProperty(runtime, "prototype", prototype); + return instance; + } - static jsi::Function getConstructor(jsi::Runtime &runtime) { - return JSI_DEFINE_FUNCTION(runtime, constructor, 0, JSI_BIND_METHOD(constructor)); - } + static jsi::Function getConstructor(jsi::Runtime &runtime) { + return JSI_DEFINE_FUNCTION(runtime, constructor, 0, + JSI_BIND_METHOD(constructor)); + } - protected: - virtual std::unordered_map getMethods() const = 0; - }; -} +protected: + virtual std::unordered_map getMethods() const = 0; +}; +} // namespace RnJsi diff --git a/cpp/helper.h b/cpp/helper.h index 7e263c9..06676bc 100644 --- a/cpp/helper.h +++ b/cpp/helper.h @@ -1,47 +1,50 @@ #pragma once -#define JSI_HOST_FUNCTION(NAME) \ - facebook::jsi::Value NAME( \ - facebook::jsi::Runtime &runtime, \ - const facebook::jsi::Value &thisValue, \ - const facebook::jsi::Value *args, \ - size_t argc \ - ) - -#define JSI_DEFINE_FUNCTION(RUNTIME, NAME, ARGC, FUNCTION) \ - facebook::jsi::Function::createFromHostFunction( \ - RUNTIME, \ - facebook::jsi::PropNameID::forAscii(RUNTIME, #NAME), \ - ARGC, \ - FUNCTION \ - ) - -#define JSI_EXPOSE_FUNCTION(RUNTIME, NAME, ARGC, FUNCTION) \ - { \ - auto NAME = JSI_DEFINE_FUNCTION(RUNTIME, NAME, ARGC, FUNCTION); \ - RUNTIME.global().setProperty(RUNTIME, #NAME, std::move(NAME)); \ +#define JSI_HOST_FUNCTION(NAME) \ + facebook::jsi::Value NAME(facebook::jsi::Runtime &runtime, \ + const facebook::jsi::Value &thisValue, \ + const facebook::jsi::Value *args, size_t argc) + +#define JSI_DEFINE_FUNCTION(RUNTIME, NAME, ARGC, FUNCTION) \ + facebook::jsi::Function::createFromHostFunction( \ + RUNTIME, facebook::jsi::PropNameID::forAscii(RUNTIME, #NAME), ARGC, \ + FUNCTION) + +#define JSI_EXPOSE_FUNCTION(RUNTIME, NAME, ARGC, FUNCTION) \ + { \ + auto NAME = JSI_DEFINE_FUNCTION(RUNTIME, NAME, ARGC, FUNCTION); \ + RUNTIME.global().setProperty(RUNTIME, #NAME, std::move(NAME)); \ } -#define JSI_BIND_METHOD(METHOD) \ - std::bind(&METHOD, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4) +#define JSI_BIND_METHOD(METHOD) \ + std::bind(&METHOD, this, std::placeholders::_1, std::placeholders::_2, \ + std::placeholders::_3, std::placeholders::_4) -#define JSI_ASSERT(RUNTIME, CONDITION, MESSAGE) \ - if (!(CONDITION)) throw facebook::jsi::JSError(RUNTIME, MESSAGE) +#define JSI_ASSERT(RUNTIME, CONDITION, MESSAGE) \ + if (!(CONDITION)) \ + throw facebook::jsi::JSError(RUNTIME, MESSAGE) -#define JSI_GET_PROPERTY_OR_DEFAULT(OBJECT, RUNTIME, PROPERTY, DEFAULT) \ - OBJECT.getProperty(RUNTIME, PROPERTY).isUndefined() ? DEFAULT : OBJECT.getProperty(RUNTIME, PROPERTY) +#define JSI_GET_PROPERTY_OR_DEFAULT(OBJECT, RUNTIME, PROPERTY, DEFAULT) \ + OBJECT.getProperty(RUNTIME, PROPERTY).isUndefined() \ + ? DEFAULT \ + : OBJECT.getProperty(RUNTIME, PROPERTY) -#define JSI_GET_ARGUMENT_OR_DEFAULT(ARGS, RUNTIME, INDEX, DEFAULT) \ +#define JSI_GET_ARGUMENT_OR_DEFAULT(ARGS, RUNTIME, INDEX, DEFAULT) \ ARGS.size() <= INDEX ? DEFAULT : ARGS[INDEX] -#define JSI_CREATE_PROMISE(RUNTIME, EXECUTOR, ...) \ - RUNTIME.global().getPropertyAsFunction(RUNTIME, "Promise").callAsConstructor(RUNTIME, facebook::jsi::Function::createFromHostFunction( \ - RUNTIME, \ - facebook::jsi::PropNameID::forAscii(RUNTIME, "executor"), \ - 2, \ - [__VA_ARGS__](facebook::jsi::Runtime &runtime, const facebook::jsi::Value &thisValue, const facebook::jsi::Value *args, size_t argc) -> facebook::jsi::Value { \ - auto resolve = args[0].getObject(runtime).asFunction(runtime); \ - auto reject = args[1].getObject(runtime).asFunction(runtime); \ - EXECUTOR \ - } \ - )) +#define JSI_CREATE_PROMISE(RUNTIME, EXECUTOR, ...) \ + RUNTIME.global() \ + .getPropertyAsFunction(RUNTIME, "Promise") \ + .callAsConstructor( \ + RUNTIME, \ + facebook::jsi::Function::createFromHostFunction( \ + RUNTIME, \ + facebook::jsi::PropNameID::forAscii(RUNTIME, "executor"), 2, \ + [__VA_ARGS__](facebook::jsi::Runtime &runtime, \ + const facebook::jsi::Value &thisValue, \ + const facebook::jsi::Value *args, \ + size_t argc) -> facebook::jsi::Value { \ + auto resolve = args[0].getObject(runtime).asFunction(runtime); \ + auto reject = args[1].getObject(runtime).asFunction(runtime); \ + EXECUTOR \ + })) diff --git a/cpp/react-native-hnswlib.cpp b/cpp/react-native-hnswlib.cpp index bd55f26..11f5952 100644 --- a/cpp/react-native-hnswlib.cpp +++ b/cpp/react-native-hnswlib.cpp @@ -3,10 +3,12 @@ #include "RnJsiContext.hpp" namespace RNHnswlib { - void install(jsi::Runtime *runtime, std::shared_ptr jsCallInvoker) { - RnJsi::Context::init(runtime, jsCallInvoker); - auto hnswlib = jsi::Object(*runtime); - hnswlib.setProperty(*runtime, "HierarchicalNSW", HierarchicalNSW::getConstructor(runtime)); - runtime->global().setProperty(runtime, "hnswlib", hnswlib); - } +void install(jsi::Runtime *runtime, + std::shared_ptr jsCallInvoker) { + RnJsi::Context::init(runtime, jsCallInvoker); + auto hnswlib = jsi::Object(*runtime); + hnswlib.setProperty(*runtime, "HierarchicalNSW", + HierarchicalNSW::getConstructor(runtime)); + runtime->global().setProperty(runtime, "hnswlib", hnswlib); } +} // namespace RNHnswlib diff --git a/cpp/react-native-hnswlib.h b/cpp/react-native-hnswlib.h index 5de70bd..9c2ba59 100644 --- a/cpp/react-native-hnswlib.h +++ b/cpp/react-native-hnswlib.h @@ -1,11 +1,12 @@ -#include #include +#include #ifndef HNSWLIB_H #define HNSWLIB_H namespace RNHnswlib { - void install(jsi::Runtime &runtime, std::shared_ptr jsCallInvoker); +void install(jsi::Runtime &runtime, + std::shared_ptr jsCallInvoker); } #endif /* HNSWLIB_H */