
Improve code & handle warnings (#113)
dorbanianas authored Dec 25, 2024
1 parent 0ad6b5b commit 475c34a
Showing 16 changed files with 102 additions and 94 deletions.
8 changes: 4 additions & 4 deletions docs/docs/aggregate-reduce-functions/llm-rerank.md
@@ -34,7 +34,7 @@ Rerank documents based on their relevance to a given query:
```sql
SELECT llm_rerank(
{'model_name': 'gpt-4'},
{'prompt': 'Rank these documents based on their relevance to the search query using the document title and content.'},
{'prompt': 'Rank documents by title keywords (AI, emerging tech), content relevance (innovative approaches), recency, and credibility.'},
{'document_title': document_title, 'document_content': document_content}
) AS reranked_documents
FROM documents;
@@ -50,7 +50,7 @@ Rerank documents for each category based on their relevance:
SELECT category,
llm_rerank(
{'model_name': 'gpt-4'},
{'prompt': 'Rank these documents based on their relevance to the search query using the document title and content.'},
{'prompt': 'Rank documents by title keywords (AI, emerging tech), content relevance (innovative approaches), recency, and credibility.'},
{'document_title': document_title, 'document_content': document_content}
) AS reranked_documents
FROM documents
@@ -86,7 +86,7 @@ WITH ranked_documents AS (
)
SELECT llm_rerank(
{'model_name': 'gpt-4'},
{'prompt': 'Rank these documents based on their relevance to the search query using the document title and content.'},
{'prompt': 'Rank documents by title keywords (AI, emerging tech), content relevance (innovative approaches), recency, and credibility.'},
{'document_title': document_title, 'document_content': document_content}
) AS reranked_documents
FROM ranked_documents;
@@ -125,7 +125,7 @@ Two types of prompts can be used:
- Directly provides the prompt in the query.
- **Example**:
```sql
{'prompt': 'Rank these documents based on their relevance to the search query using the document title and content.'}
{'prompt': 'Rank documents by title keywords (AI, emerging tech), content relevance (innovative approaches), recency, and credibility.'}
```

2. **Named Prompt**
10 changes: 5 additions & 5 deletions src/core/config/model.cpp
@@ -29,13 +29,13 @@ void Config::SetupDefaultModelsConfig(duckdb::Connection& con, std::string& sche
con.Query(duckdb_fmt::format(
" INSERT INTO {}.{} (model_name, model, provider_name, model_args) "
" VALUES "
" ('default', 'gpt-4o-mini', 'openai', '{{\"context_window\": 128000, \"max_output_tokens\": 16384}}'),"
" ('gpt-4o-mini', 'gpt-4o-mini', 'openai', '{{\"context_window\": 128000, \"max_output_tokens\": 16384}}'),"
" ('gpt-4o', 'gpt-4o', 'openai', '{{\"context_window\": 128000, \"max_output_tokens\": 16384}}'),"
" ('default', 'gpt-4o-mini', 'openai', '{{\"context_window\":128000,\"max_output_tokens\":16384}}'),"
" ('gpt-4o-mini', 'gpt-4o-mini', 'openai', '{{\"context_window\":128000,\"max_output_tokens\":16384}}'),"
" ('gpt-4o', 'gpt-4o', 'openai', '{{\"context_window\":128000,\"max_output_tokens\":16384}}'),"
" ('text-embedding-3-large', 'text-embedding-3-large', 'openai', "
" '{{\"context_window\": {}, \"max_output_tokens\": {}}}'),"
" '{{\"context_window\":{},\"max_output_tokens\":{}}}'),"
" ('text-embedding-3-small', 'text-embedding-3-small', 'openai', "
" '{{\"context_window\": {}, \"max_output_tokens\": {}}}')",
" '{{\"context_window\":{},\"max_output_tokens\":{}}}')",
schema_name, table_name, Config::default_context_window, Config::default_max_output_tokens,
Config::default_context_window, Config::default_max_output_tokens));
}
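Note: the doubled braces in these format strings are fmt-style escapes rather than JSON typos; `{{`/`}}` come out as literal braces while `{}` receives the default context window and token limits. A minimal stand-alone sketch of the same pattern, assuming plain fmtlib behaves like the vendored duckdb_fmt and using hypothetical values:

```cpp
#include <fmt/core.h>
#include <iostream>

int main() {
    // {{ and }} are emitted as literal braces; {} are replacement fields,
    // so the JSON shape survives while the numeric limits are injected.
    const int context_window = 8191;      // hypothetical defaults
    const int max_output_tokens = 8191;
    const auto model_args = fmt::format(
        "'{{\"context_window\":{},\"max_output_tokens\":{}}}'",
        context_window, max_output_tokens);
    std::cout << model_args << '\n';
    // prints: '{"context_window":8191,"max_output_tokens":8191}'
    return 0;
}
```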
15 changes: 8 additions & 7 deletions src/custom_parser/tokenizer.cpp
@@ -11,7 +11,7 @@ Tokenizer::Tokenizer(const std::string& query) : query_(query), position_(0) {}

// Skip whitespace
void Tokenizer::SkipWhitespace() {
while (position_ < query_.size() && std::isspace(query_[position_])) {
while (position_ < static_cast<int>(query_.size()) && std::isspace(query_[position_])) {
++position_;
}
}
@@ -25,10 +25,10 @@ Token Tokenizer::ParseStringLiteral() {
}
++position_;
int start = position_;
while (position_ < query_.size() && query_[position_] != '\'') {
while (position_ < static_cast<int>(query_.size()) && query_[position_] != '\'') {
++position_;
}
if (position_ == query_.size()) {
if (position_ == static_cast<int>(query_.size())) {
throw std::runtime_error("Unterminated string literal.");
}
std::string value = query_.substr(start, position_ - start);
@@ -42,7 +42,7 @@ Token Tokenizer::ParseJson() {
}
auto start = position_++;
auto brace_count = 1;
while (position_ < query_.size() && brace_count > 0) {
while (position_ < static_cast<int>(query_.size()) && brace_count > 0) {
if (query_[position_] == '{') {
++brace_count;
} else if (query_[position_] == '}') {
@@ -60,7 +60,8 @@ Token Tokenizer::ParseJson() {
// Parse a keyword (word made of letters)
Token Tokenizer::ParseKeyword() {
auto start = position_;
while (position_ < query_.size() && (std::isalpha(query_[position_]) || query_[position_] == '_')) {
while (position_ < static_cast<int>(query_.size()) &&
(std::isalpha(query_[position_]) || query_[position_] == '_')) {
++position_;
}
auto value = query_.substr(start, position_ - start);
@@ -77,7 +78,7 @@ Token Tokenizer::ParseSymbol() {
// Parse a number (sequence of digits)
Token Tokenizer::ParseNumber() {
auto start = position_;
while (position_ < query_.size() && std::isdigit(query_[position_])) {
while (position_ < static_cast<int>(query_.size()) && std::isdigit(query_[position_])) {
++position_;
}
auto value = query_.substr(start, position_ - start);
@@ -94,7 +95,7 @@ Token Tokenizer::ParseParenthesis() {
// Get the next token from the input
Token Tokenizer::GetNextToken() {
SkipWhitespace();
if (position_ >= query_.size()) {
if (position_ >= static_cast<int>(query_.size())) {
return {TokenType::END_OF_FILE, ""};
}

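Note: most of the casts in this file, and the analogous ones in the aggregate and scalar functions below, silence the same `-Wsign-compare` warning: `position_` is evidently a signed `int`, while `std::string::size()` returns an unsigned `size_t`. A stand-alone sketch of the pattern (illustrative only, not project code):

```cpp
#include <cctype>
#include <string>

// Counts leading whitespace the way the tokenizer scans its input.
int skip_whitespace(const std::string& query) {
    int position = 0;
    // `position < query.size()` compares signed with unsigned and triggers
    // -Wsign-compare; casting the size keeps both operands signed.
    while (position < static_cast<int>(query.size()) &&
           std::isspace(static_cast<unsigned char>(query[position]))) {
        ++position;
    }
    return position;
}
```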
12 changes: 7 additions & 5 deletions src/functions/aggregate/llm_first_or_last/implementation.cpp
@@ -33,9 +33,11 @@ nlohmann::json LlmFirstOrLast::Evaluate(nlohmann::json& tuples) {
accumulated_tuples_tokens = Tiktoken::GetNumTokens(batch_tuples.dump());
accumulated_tuples_tokens +=
Tiktoken::GetNumTokens(PromptManager::ConstructMarkdownHeader(tuples[start_index]));
while (accumulated_tuples_tokens < available_tokens && start_index < tuples.size()) {
auto num_tokens = Tiktoken::GetNumTokens(PromptManager::ConstructMarkdownSingleTuple(tuples[start_index]));
if (accumulated_tuples_tokens + num_tokens > available_tokens) {
while (accumulated_tuples_tokens < static_cast<unsigned int>(available_tokens) &&
start_index < static_cast<int>(tuples.size())) {
const auto num_tokens =
Tiktoken::GetNumTokens(PromptManager::ConstructMarkdownSingleTuple(tuples[start_index]));
if (accumulated_tuples_tokens + num_tokens > static_cast<unsigned int>(available_tokens)) {
break;
}
batch_tuples.push_back(tuples[start_index]);
@@ -45,7 +47,7 @@ nlohmann::json LlmFirstOrLast::Evaluate(nlohmann::json& tuples) {
auto result_idx = GetFirstOrLastTupleId(batch_tuples);
batch_tuples.clear();
batch_tuples.push_back(tuples[result_idx]);
} while (start_index < tuples.size());
} while (start_index < static_cast<int>(tuples.size()));
batch_tuples[0].erase("flockmtl_tuple_id");

return batch_tuples[0];
@@ -62,7 +64,7 @@ void LlmFirstOrLast::FinalizeResults(duckdb::Vector& states, duckdb::AggregateIn
auto state_ptr = states_vector[idx];
auto state = function_instance->state_map[state_ptr];
auto tuples_with_ids = nlohmann::json::array();
for (auto j = 0; j < state->value.size(); j++) {
for (auto j = 0; j < static_cast<int>(state->value.size()); j++) {
auto tuple_with_id = state->value[j];
tuple_with_id["flockmtl_tuple_id"] = j;
tuples_with_ids.push_back(tuple_with_id);
10 changes: 6 additions & 4 deletions src/functions/aggregate/llm_reduce/implementation.cpp
@@ -34,9 +34,11 @@ nlohmann::json LlmReduce::ReduceLoop(const std::vector<nlohmann::json>& tuples)
accumulated_tuples_tokens = Tiktoken::GetNumTokens(batch_tuples.dump());
accumulated_tuples_tokens +=
Tiktoken::GetNumTokens(PromptManager::ConstructMarkdownHeader(tuples[start_index]));
while (accumulated_tuples_tokens < available_tokens && start_index < tuples.size()) {
auto num_tokens = Tiktoken::GetNumTokens(PromptManager::ConstructMarkdownSingleTuple(tuples[start_index]));
if (accumulated_tuples_tokens + num_tokens > available_tokens) {
while (accumulated_tuples_tokens < static_cast<unsigned int>(available_tokens) &&
start_index < static_cast<int>(tuples.size())) {
const auto num_tokens =
Tiktoken::GetNumTokens(PromptManager::ConstructMarkdownSingleTuple(tuples[start_index]));
if (accumulated_tuples_tokens + num_tokens > static_cast<unsigned int>(available_tokens)) {
break;
}
batch_tuples.push_back(tuples[start_index]);
@@ -47,7 +49,7 @@ nlohmann::json LlmReduce::ReduceLoop(const std::vector<nlohmann::json>& tuples)
batch_tuples.clear();
batch_tuples.push_back(response);
accumulated_tuples_tokens = 0u;
} while (start_index < tuples.size());
} while (start_index < static_cast<int>(tuples.size()));

return batch_tuples[0];
}
6 changes: 3 additions & 3 deletions src/functions/aggregate/llm_rerank/implementation.cpp
@@ -43,7 +43,7 @@ nlohmann::json LlmRerank::SlidingWindow(nlohmann::json& tuples) {
accumulated_rows_tokens += Tiktoken::GetNumTokens(PromptManager::ConstructMarkdownHeader(tuples[start_index]));
while (available_tokens - accumulated_rows_tokens > 0 && start_index >= 0) {
auto num_tokens = Tiktoken::GetNumTokens(PromptManager::ConstructMarkdownSingleTuple(tuples[start_index]));
if (accumulated_rows_tokens + num_tokens > available_tokens) {
if (accumulated_rows_tokens + num_tokens > static_cast<unsigned int>(available_tokens)) {
break;
}
window_tuples.push_back(tuples[start_index]);
@@ -53,7 +53,7 @@ nlohmann::json LlmRerank::SlidingWindow(nlohmann::json& tuples) {
}

auto indexed_tuples = nlohmann::json::array();
for (auto i = 0; i < window_tuples.size(); i++) {
for (auto i = 0; i < static_cast<int>(window_tuples.size()); i++) {
auto indexed_tuple = window_tuples[i];
indexed_tuple["flockmtl_tuple_id"] = i;
indexed_tuples.push_back(indexed_tuple);
@@ -81,7 +81,7 @@ void LlmRerank::Finalize(duckdb::Vector& states, duckdb::AggregateInputData& agg
auto state = function_instance->state_map[state_ptr];

auto tuples_with_ids = nlohmann::json::array();
for (auto j = 0; j < state->value.size(); j++) {
for (auto j = 0; j < static_cast<int>(state->value.size()); j++) {
tuples_with_ids.push_back(state->value[j]);
}
auto reranked_tuples = function_instance->SlidingWindow(tuples_with_ids);
4 changes: 2 additions & 2 deletions src/functions/batch_response_builder.cpp
@@ -6,8 +6,8 @@ std::vector<nlohmann::json> CastVectorOfStructsToJson(const duckdb::Vector& stru
std::vector<nlohmann::json> vector_json;
for (auto i = 0; i < size; i++) {
nlohmann::json json;
for (auto j = 0; j < duckdb::StructType::GetChildCount(struct_vector.GetType()); j++) {
auto key = duckdb::StructType::GetChildName(struct_vector.GetType(), j);
for (auto j = 0; j < static_cast<int>(duckdb::StructType::GetChildCount(struct_vector.GetType())); j++) {
const auto key = duckdb::StructType::GetChildName(struct_vector.GetType(), j);
auto value = duckdb::StructValue::GetChildren(struct_vector.GetValue(i))[j].ToString();
json[key] = value;
}
6 changes: 3 additions & 3 deletions src/functions/scalar/fusion_relative/implementation.cpp
@@ -3,7 +3,7 @@
namespace flockmtl {

void FusionRelative::ValidateArguments(duckdb::DataChunk& args) {
for (int i = 0; i < args.ColumnCount(); i++) {
for (int i = 0; i < static_cast<int>(args.ColumnCount()); i++) {
if (args.data[i].GetType() != duckdb::LogicalType::DOUBLE) {
throw std::runtime_error("fusion_relative: argument must be a double");
}
@@ -14,9 +14,9 @@ std::vector<double> FusionRelative::Operation(duckdb::DataChunk& args) {
FusionRelative::ValidateArguments(args);

std::vector<double> results;
for (auto i = 0; i < args.size(); i++) {
for (auto i = 0; i < static_cast<int>(args.size()); i++) {
auto max = 0.0;
for (auto j = 0; j < args.ColumnCount(); j++) {
for (auto j = 0; j < static_cast<int>(args.ColumnCount()); j++) {
auto valueWrapper = args.data[j].GetValue(i);
if (!valueWrapper.IsNull()) {
auto value = valueWrapper.GetValue<double>();
1 change: 0 additions & 1 deletion src/functions/scalar/llm_filter/implementation.cpp
@@ -33,7 +33,6 @@ std::vector<std::string> LlmFilter::Operation(duckdb::DataChunk& args) {

auto responses = BatchAndComplete(tuples, prompt_details.prompt, ScalarFunctionType::FILTER, model);

auto index = 0;
std::vector<std::string> results;
results.reserve(responses.size());
for (const auto& response : responses) {
10 changes: 5 additions & 5 deletions src/functions/scalar/scalar.cpp
@@ -34,11 +34,11 @@ nlohmann::json ScalarFunctionBase::BatchAndComplete(const std::vector<nlohmann::
do {
accumulated_tuples_tokens +=
Tiktoken::GetNumTokens(PromptManager::ConstructMarkdownHeader(tuples[start_index]));
while (accumulated_tuples_tokens < available_tokens && start_index < tuples.size() &&
batch_tuples.size() < batch_size) {
auto num_tokens =
while (accumulated_tuples_tokens < static_cast<unsigned int>(available_tokens) &&
start_index < static_cast<int>(tuples.size()) && batch_tuples.size() < batch_size) {
const auto num_tokens =
Tiktoken::GetNumTokens(PromptManager::ConstructMarkdownSingleTuple(tuples[start_index]));
if (accumulated_tuples_tokens + num_tokens > available_tokens) {
if (accumulated_tuples_tokens + num_tokens > static_cast<unsigned int>(available_tokens)) {
break;
}
batch_tuples.push_back(tuples[start_index]);
@@ -67,7 +67,7 @@ nlohmann::json ScalarFunctionBase::BatchAndComplete(const std::vector<nlohmann::
responses.push_back(tuple);
}

} while (start_index < tuples.size());
} while (start_index < static_cast<int>(tuples.size()));
}

return responses;
8 changes: 4 additions & 4 deletions src/include/flockmtl/functions/aggregate/aggregate.hpp
@@ -85,12 +85,12 @@ class AggregateFunctionBase {

template <class Derived>
static void Combine(duckdb::Vector& source, duckdb::Vector& target, duckdb::AggregateInputData& aggr_input_data,
idx_t count) {
auto source_vector = duckdb::FlatVector::GetData<AggregateFunctionState*>(source);
auto target_vector = duckdb::FlatVector::GetData<AggregateFunctionState*>(target);
const idx_t count) {
const auto source_vector = duckdb::FlatVector::GetData<AggregateFunctionState*>(source);
const auto target_vector = duckdb::FlatVector::GetData<AggregateFunctionState*>(target);

auto function_instance = GetInstance<Derived>();
for (auto i = 0; i < count; i++) {
for (auto i = 0; i < static_cast<int>(count); i++) {
auto source_ptr = source_vector[i];
auto target_ptr = target_vector[i];

12 changes: 6 additions & 6 deletions src/include/flockmtl/model_manager/providers/handlers/ollama.hpp
@@ -12,16 +12,16 @@ namespace flockmtl {

class OllamaModelManager {
public:
OllamaModelManager(const std::string& url, bool throw_exception)
: _session("Ollama", throw_exception), _url(url), _throw_exception(throw_exception) {}
OllamaModelManager(const std::string& url, const bool throw_exception)
: _session("Ollama", throw_exception), _throw_exception(throw_exception), _url(url) {}
OllamaModelManager(const OllamaModelManager&) = delete;
OllamaModelManager& operator=(const OllamaModelManager&) = delete;
OllamaModelManager(OllamaModelManager&&) = delete;
OllamaModelManager& operator=(OllamaModelManager&&) = delete;

std::string GetChatUrl() { return _url + "/api/generate"; }
std::string GetChatUrl() const { return _url + "/api/generate"; }

std::string GetEmbedUrl() { return _url + "/api/embeddings"; }
std::string GetEmbedUrl() const { return _url + "/api/embeddings"; }

std::string GetAvailableOllamaModelsUrl() {
static int check_done = -1;
@@ -40,13 +40,13 @@ class OllamaModelManager {
}

nlohmann::json CallComplete(const nlohmann::json& json, const std::string& contentType = "application/json") {
std::string url = GetChatUrl();
const std::string url = GetChatUrl();
_session.setUrl(url);
return execute_post(json.dump(), contentType);
}

nlohmann::json CallEmbedding(const nlohmann::json& json, const std::string& contentType = "application/json") {
std::string url = GetEmbedUrl();
const std::string url = GetEmbedUrl();
_session.setUrl(url);
return execute_post(json.dump(), contentType);
}
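Note: besides marking the URL getters `const`, the constructor's initializer list now matches the member declaration order. Members are always initialized in declaration order, so a mismatched list earns `-Wreorder` and can mislead readers about which member is ready first. A toy illustration with a hypothetical class, not the real OllamaModelManager:

```cpp
#include <string>

struct ToyManager {
    bool throw_exception_;  // declared first
    std::string url_;       // declared second

    // Initializer list follows declaration order, so no -Wreorder warning
    // and the written order reflects what actually happens.
    ToyManager(const std::string& url, const bool throw_exception)
        : throw_exception_(throw_exception), url_(url) {}
};
```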
10 changes: 5 additions & 5 deletions src/include/flockmtl/secret_manager/secret_manager.hpp
@@ -15,18 +15,18 @@ class SecretDetails {
std::vector<std::string> required_fields;
};

extern const SecretDetails openai_secret_details;
extern const SecretDetails azure_secret_details;
extern const SecretDetails ollama_secret_details;
SecretDetails get_openai_secret_details();
SecretDetails get_azure_secret_details();
SecretDetails get_ollama_secret_details();

class SecretManager {
public:
enum SupportedProviders { OPENAI, AZURE, OLLAMA };
static std::unordered_map<std::string, SupportedProviders> providerNames;

static void Register(duckdb::DatabaseInstance& instance);
static std::unordered_map<std::string, std::string> GetSecret(std::string provider);
static SupportedProviders GetProviderType(std::string provider);
static std::unordered_map<std::string, std::string> GetSecret(const std::string& secret_name);
static SupportedProviders GetProviderType(const std::string& provider);

private:
static void RegisterSecretType(duckdb::DatabaseInstance& instance);
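Note: swapping the `extern const SecretDetails` globals for getter functions avoids depending on static initialization order across translation units. The new definitions are not part of this hunk; one common way to write such getters is the construct-on-first-use idiom, sketched here with a hypothetical field value:

```cpp
#include <string>
#include <vector>

struct SecretDetails {
    std::vector<std::string> required_fields;  // field taken from the header above
};

// The function-local static is constructed on first call, so callers in other
// translation units can use it safely even during their own static initialization.
SecretDetails get_openai_secret_details() {
    static const SecretDetails details{{"api_key"}};  // "api_key" is a hypothetical entry
    return details;
}
```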
2 changes: 1 addition & 1 deletion src/model_manager/model.cpp
@@ -67,7 +67,7 @@ std::tuple<std::string, std::string, int32_t, int32_t> Model::GetQueriedModel(co
}

void Model::ConstructProvider() {
switch (auto provider = GetProviderType(model_details_.provider_name)) {
switch (GetProviderType(model_details_.provider_name)) {
case FLOCKMTL_OPENAI:
provider_ = std::make_shared<OpenAIProvider>(model_details_);
break;
2 changes: 1 addition & 1 deletion src/prompt_manager/prompt_manager.cpp
@@ -44,7 +44,7 @@ std::string PromptManager::ConstructMarkdownHeader(const nlohmann::json& tuple)
header_markdown += key.key() + " | ";
}
header_markdown += "\n";
for (auto i = 0; i < tuple.size(); i++) {
for (auto i = 0; i < static_cast<int>(tuple.size()); i++) {
header_markdown += "|---";
}
header_markdown += "|\n";