From 7319174285dcc4b58c7c98a0fb29db23de21e9bc Mon Sep 17 00:00:00 2001 From: DuckDB Labs GitHub Bot Date: Fri, 20 Dec 2024 00:34:05 +0000 Subject: [PATCH] Update vendored DuckDB sources to a9bf1a6c --- .../extension/parquet/parquet_extension.cpp | 5 +- .../common/vector_operations/vector_hash.cpp | 4 +- .../csv_scanner/sniffer/csv_sniffer.cpp | 2 +- .../csv_scanner/sniffer/dialect_detection.cpp | 19 +- .../function/table/version/pragma_version.cpp | 6 +- .../window/window_boundaries_state.cpp | 6 + .../src/function/window/window_executor.cpp | 4 + .../window/window_merge_sort_tree.cpp | 31 +++- .../function/window/window_rank_function.cpp | 4 - .../window/window_rownumber_function.cpp | 163 +++++++++++++++++- .../function/window/window_value_function.cpp | 54 ------ .../duckdb/execution/merge_sort_tree.hpp | 20 +-- .../function/window/window_executor.hpp | 3 + .../window/window_merge_sort_tree.hpp | 2 +- .../function/window/window_rank_function.hpp | 3 - .../window/window_rownumber_function.hpp | 17 ++ .../function/window/window_token_tree.hpp | 8 +- .../function/window/window_value_function.hpp | 11 -- .../expression/transform_function.cpp | 4 +- 19 files changed, 253 insertions(+), 113 deletions(-) diff --git a/src/duckdb/extension/parquet/parquet_extension.cpp b/src/duckdb/extension/parquet/parquet_extension.cpp index 1c9c09ec..8cbea809 100644 --- a/src/duckdb/extension/parquet/parquet_extension.cpp +++ b/src/duckdb/extension/parquet/parquet_extension.cpp @@ -797,15 +797,18 @@ class ParquetScanFunction { auto &gstate = data_p.global_state->Cast(); auto &bind_data = data_p.bind_data->CastNoConst(); + bool rowgroup_finished; do { if (gstate.CanRemoveColumns()) { data.all_columns.Reset(); data.reader->Scan(data.scan_state, data.all_columns); + rowgroup_finished = data.all_columns.size() == 0; bind_data.multi_file_reader->FinalizeChunk(context, bind_data.reader_bind, data.reader->reader_data, data.all_columns, gstate.multi_file_reader_state); output.ReferenceColumns(data.all_columns, gstate.projection_ids); } else { data.reader->Scan(data.scan_state, output); + rowgroup_finished = output.size() == 0; bind_data.multi_file_reader->FinalizeChunk(context, bind_data.reader_bind, data.reader->reader_data, output, gstate.multi_file_reader_state); } @@ -814,7 +817,7 @@ class ParquetScanFunction { if (output.size() > 0) { return; } - if (!ParquetParallelStateNext(context, bind_data, data, gstate)) { + if (rowgroup_finished && !ParquetParallelStateNext(context, bind_data, data, gstate)) { return; } } while (true); diff --git a/src/duckdb/src/common/vector_operations/vector_hash.cpp b/src/duckdb/src/common/vector_operations/vector_hash.cpp index e6ef5f5f..c82422c2 100644 --- a/src/duckdb/src/common/vector_operations/vector_hash.cpp +++ b/src/duckdb/src/common/vector_operations/vector_hash.cpp @@ -21,7 +21,9 @@ struct HashOp { }; static inline hash_t CombineHashScalar(hash_t a, hash_t b) { - return (a * UINT64_C(0xbf58476d1ce4e5b9)) ^ b; + a ^= a >> 32; + a *= 0xd6e8feb86659fd93U; + return a ^ b; } template diff --git a/src/duckdb/src/execution/operator/csv_scanner/sniffer/csv_sniffer.cpp b/src/duckdb/src/execution/operator/csv_scanner/sniffer/csv_sniffer.cpp index 74928a5c..7d6e2e3e 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/sniffer/csv_sniffer.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/sniffer/csv_sniffer.cpp @@ -6,7 +6,7 @@ namespace duckdb { CSVSniffer::CSVSniffer(CSVReaderOptions &options_p, shared_ptr buffer_manager_p, CSVStateMachineCache &state_machine_cache_p, bool default_null_to_varchar_p) : state_machine_cache(state_machine_cache_p), options(options_p), buffer_manager(std::move(buffer_manager_p)), - default_null_to_varchar(default_null_to_varchar_p) { + lines_sniffed(0), default_null_to_varchar(default_null_to_varchar_p) { // Initialize Format Candidates for (const auto &format_template : format_template_candidates) { auto &logical_type = format_template.first; diff --git a/src/duckdb/src/execution/operator/csv_scanner/sniffer/dialect_detection.cpp b/src/duckdb/src/execution/operator/csv_scanner/sniffer/dialect_detection.cpp index ae06d024..14099df8 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/sniffer/dialect_detection.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/sniffer/dialect_detection.cpp @@ -80,11 +80,11 @@ string DialectCandidates::Print() { DialectCandidates::DialectCandidates(const CSVStateMachineOptions &options) { // assert that quotes escapes and rules have equal size - auto default_quote = GetDefaultQuote(); - auto default_escape = GetDefaultEscape(); - auto default_quote_rule = GetDefaultQuoteRule(); - auto default_delimiter = GetDefaultDelimiter(); - auto default_comment = GetDefaultComment(); + const auto default_quote = GetDefaultQuote(); + const auto default_escape = GetDefaultEscape(); + const auto default_quote_rule = GetDefaultQuoteRule(); + const auto default_delimiter = GetDefaultDelimiter(); + const auto default_comment = GetDefaultComment(); D_ASSERT(default_quote.size() == default_quote_rule.size() && default_quote_rule.size() == default_escape.size()); // fill the escapes @@ -187,6 +187,9 @@ void CSVSniffer::GenerateStateMachineSearchSpace(vector scanner, idx_t num_cols = sniffed_column_counts.result_position == 0 ? 1 : sniffed_column_counts[0].number_of_columns; const bool ignore_errors = options.ignore_errors.GetValue(); // If we are ignoring errors and not null_padding , we pick the most frequent number of columns as the right one - bool use_most_frequent_columns = ignore_errors && !options.null_padding; + const bool use_most_frequent_columns = ignore_errors && !options.null_padding; if (use_most_frequent_columns) { num_cols = sniffed_column_counts.GetMostFrequentColumnCount(); } @@ -355,7 +358,7 @@ void CSVSniffer::AnalyzeDialectCandidate(unique_ptr scanner, // - There's a single column before. // - There are more values and no additional padding is required. // - There's more than one column and less padding is required. - if (columns_match_set && rows_consistent && + if (columns_match_set && (rows_consistent || (set_columns.IsSet() && ignore_errors)) && (single_column_before || ((more_values || more_columns) && !require_more_padding) || (more_than_one_column && require_less_padding) || quoted) && !invalid_padding && comments_are_acceptable) { diff --git a/src/duckdb/src/function/table/version/pragma_version.cpp b/src/duckdb/src/function/table/version/pragma_version.cpp index e7e0a610..53720a79 100644 --- a/src/duckdb/src/function/table/version/pragma_version.cpp +++ b/src/duckdb/src/function/table/version/pragma_version.cpp @@ -1,5 +1,5 @@ #ifndef DUCKDB_PATCH_VERSION -#define DUCKDB_PATCH_VERSION "4-dev3722" +#define DUCKDB_PATCH_VERSION "4-dev3741" #endif #ifndef DUCKDB_MINOR_VERSION #define DUCKDB_MINOR_VERSION 1 @@ -8,10 +8,10 @@ #define DUCKDB_MAJOR_VERSION 1 #endif #ifndef DUCKDB_VERSION -#define DUCKDB_VERSION "v1.1.4-dev3722" +#define DUCKDB_VERSION "v1.1.4-dev3741" #endif #ifndef DUCKDB_SOURCE_ID -#define DUCKDB_SOURCE_ID "62582045a3" +#define DUCKDB_SOURCE_ID "ab8c909857" #endif #include "duckdb/function/table/system_functions.hpp" #include "duckdb/main/database.hpp" diff --git a/src/duckdb/src/function/window/window_boundaries_state.cpp b/src/duckdb/src/function/window/window_boundaries_state.cpp index 52fe91e2..92f14860 100644 --- a/src/duckdb/src/function/window/window_boundaries_state.cpp +++ b/src/duckdb/src/function/window/window_boundaries_state.cpp @@ -302,6 +302,10 @@ WindowBoundsSet WindowBoundariesState::GetWindowBounds(const BoundWindowExpressi switch (wexpr.GetExpressionType()) { case ExpressionType::WINDOW_ROW_NUMBER: result.insert(PARTITION_BEGIN); + if (!wexpr.arg_orders.empty()) { + // Secondary orders need to know how wide the partition is + result.insert(PARTITION_END); + } break; case ExpressionType::WINDOW_RANK_DENSE: case ExpressionType::WINDOW_RANK: @@ -309,6 +313,7 @@ WindowBoundsSet WindowBoundariesState::GetWindowBounds(const BoundWindowExpressi if (wexpr.arg_orders.empty()) { result.insert(PEER_BEGIN); } else { + // Secondary orders need to know how wide the partition is result.insert(PARTITION_END); } break; @@ -316,6 +321,7 @@ WindowBoundsSet WindowBoundariesState::GetWindowBounds(const BoundWindowExpressi result.insert(PARTITION_BEGIN); result.insert(PARTITION_END); if (wexpr.arg_orders.empty()) { + // Secondary orders need to know where the first peer is result.insert(PEER_BEGIN); } break; diff --git a/src/duckdb/src/function/window/window_executor.cpp b/src/duckdb/src/function/window/window_executor.cpp index d8f4286d..40381f0e 100644 --- a/src/duckdb/src/function/window/window_executor.cpp +++ b/src/duckdb/src/function/window/window_executor.cpp @@ -41,6 +41,10 @@ WindowExecutor::WindowExecutor(BoundWindowExpression &wexpr, ClientContext &cont boundary_start_idx = shared.RegisterEvaluate(wexpr.start_expr); boundary_end_idx = shared.RegisterEvaluate(wexpr.end_expr); + + for (const auto &order : wexpr.arg_orders) { + arg_order_idx.emplace_back(shared.RegisterSink(order.expression)); + } } WindowExecutorGlobalState::WindowExecutorGlobalState(const WindowExecutor &executor, const idx_t payload_count, diff --git a/src/duckdb/src/function/window/window_merge_sort_tree.cpp b/src/duckdb/src/function/window/window_merge_sort_tree.cpp index c1e7b11e..b2d7db4d 100644 --- a/src/duckdb/src/function/window/window_merge_sort_tree.cpp +++ b/src/duckdb/src/function/window/window_merge_sort_tree.cpp @@ -1,4 +1,5 @@ #include "duckdb/function/window/window_merge_sort_tree.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" #include #include @@ -6,7 +7,7 @@ namespace duckdb { WindowMergeSortTree::WindowMergeSortTree(ClientContext &context, const vector &orders, - const vector &sort_idx, const idx_t count) + const vector &sort_idx, const idx_t count, bool unique) : context(context), memory_per_thread(PhysicalOperator::GetMaxThreadMemory(context)), sort_idx(sort_idx), build_stage(PartitionSortStage::INIT), tasks_completed(0) { // Sort the unfiltered indices by the orders @@ -26,7 +27,19 @@ WindowMergeSortTree::WindowMergeSortTree(ClientContext &context, const vector(buffer_manager, orders, payload_layout); + if (unique) { + vector unique_orders; + for (const auto &order : orders) { + unique_orders.emplace_back(order.Copy()); + } + auto unique_expr = make_uniq(Value(index_type)); + const auto order_type = OrderType::ASCENDING; + const auto order_by_type = OrderByNullType::NULLS_LAST; + unique_orders.emplace_back(BoundOrderByNode(order_type, order_by_type, std::move(unique_expr))); + global_sort = make_uniq(buffer_manager, unique_orders, payload_layout); + } else { + global_sort = make_uniq(buffer_manager, orders, payload_layout); + } global_sort->external = ClientConfig::GetConfig(context).force_external; } @@ -48,18 +61,22 @@ WindowMergeSortTreeLocalState::WindowMergeSortTreeLocalState(WindowMergeSortTree void WindowMergeSortTreeLocalState::SinkChunk(DataChunk &chunk, const idx_t row_idx, optional_ptr filter_sel, idx_t filtered) { + // Sequence the payload column + auto &indices = payload_chunk.data[0]; + payload_chunk.SetCardinality(chunk); + indices.Sequence(int64_t(row_idx), 1, payload_chunk.size()); + // Reference the sort columns auto &sort_idx = window_tree.sort_idx; for (column_t c = 0; c < sort_idx.size(); ++c) { sort_chunk.data[c].Reference(chunk.data[sort_idx[c]]); } + // Add the row numbers if we are uniquifying + if (sort_idx.size() < sort_chunk.ColumnCount()) { + sort_chunk.data[sort_idx.size()].Reference(indices); + } sort_chunk.SetCardinality(chunk); - // Sequence the payload column - auto &indices = payload_chunk.data[0]; - payload_chunk.SetCardinality(sort_chunk); - indices.Sequence(int64_t(row_idx), 1, payload_chunk.size()); - // Apply FILTER clause, if any if (filter_sel) { sort_chunk.Slice(*filter_sel, filtered); diff --git a/src/duckdb/src/function/window/window_rank_function.cpp b/src/duckdb/src/function/window/window_rank_function.cpp index c5dcbc63..05a39094 100644 --- a/src/duckdb/src/function/window/window_rank_function.cpp +++ b/src/duckdb/src/function/window/window_rank_function.cpp @@ -93,10 +93,6 @@ void WindowPeerLocalState::NextRank(idx_t partition_begin, idx_t peer_begin, idx WindowPeerExecutor::WindowPeerExecutor(BoundWindowExpression &wexpr, ClientContext &context, WindowSharedExpressions &shared) : WindowExecutor(wexpr, context, shared) { - - for (const auto &order : wexpr.arg_orders) { - arg_order_idx.emplace_back(shared.RegisterSink(order.expression)); - } } unique_ptr WindowPeerExecutor::GetGlobalState(const idx_t payload_count, diff --git a/src/duckdb/src/function/window/window_rownumber_function.cpp b/src/duckdb/src/function/window/window_rownumber_function.cpp index d87b0fae..71f27fa3 100644 --- a/src/duckdb/src/function/window/window_rownumber_function.cpp +++ b/src/duckdb/src/function/window/window_rownumber_function.cpp @@ -1,7 +1,78 @@ #include "duckdb/function/window/window_rownumber_function.hpp" +#include "duckdb/function/window/window_shared_expressions.hpp" +#include "duckdb/function/window/window_token_tree.hpp" +#include "duckdb/planner/expression/bound_window_expression.hpp" namespace duckdb { +//===--------------------------------------------------------------------===// +// WindowRowNumberGlobalState +//===--------------------------------------------------------------------===// +class WindowRowNumberGlobalState : public WindowExecutorGlobalState { +public: + WindowRowNumberGlobalState(const WindowRowNumberExecutor &executor, const idx_t payload_count, + const ValidityMask &partition_mask, const ValidityMask &order_mask) + : WindowExecutorGlobalState(executor, payload_count, partition_mask, order_mask), + ntile_idx(executor.ntile_idx) { + if (!executor.arg_order_idx.empty()) { + // "The ROW_NUMBER function can be computed by disambiguating duplicate elements based on their position in + // the input data, such that two elements never compare as equal." + token_tree = make_uniq(executor.context, executor.wexpr.arg_orders, executor.arg_order_idx, + payload_count, true); + } + } + + //! The token tree for ORDER BY arguments + unique_ptr token_tree; + + //! The evaluation index for NTILE + const column_t ntile_idx; +}; + +//===--------------------------------------------------------------------===// +// WindowRowNumberLocalState +//===--------------------------------------------------------------------===// +class WindowRowNumberLocalState : public WindowExecutorBoundsState { +public: + explicit WindowRowNumberLocalState(const WindowRowNumberGlobalState &grstate) + : WindowExecutorBoundsState(grstate), grstate(grstate) { + if (grstate.token_tree) { + local_tree = grstate.token_tree->GetLocalState(); + } + } + + //! Accumulate the secondary sort values + void Sink(WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, DataChunk &coll_chunk, + idx_t input_idx) override; + //! Finish the sinking and prepare to scan + void Finalize(WindowExecutorGlobalState &gstate, CollectionPtr collection) override; + + //! The corresponding global peer state + const WindowRowNumberGlobalState &grstate; + //! The optional sorting state for secondary sorts + unique_ptr local_tree; +}; + +void WindowRowNumberLocalState::Sink(WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, DataChunk &coll_chunk, + idx_t input_idx) { + WindowExecutorBoundsState::Sink(gstate, sink_chunk, coll_chunk, input_idx); + + if (local_tree) { + auto &local_tokens = local_tree->Cast(); + local_tokens.SinkChunk(sink_chunk, input_idx, nullptr, 0); + } +} + +void WindowRowNumberLocalState::Finalize(WindowExecutorGlobalState &gstate, CollectionPtr collection) { + WindowExecutorBoundsState::Finalize(gstate, collection); + + if (local_tree) { + auto &local_tokens = local_tree->Cast(); + local_tokens.Sort(); + local_tokens.window_tree.Build(); + } +} + //===--------------------------------------------------------------------===// // WindowRowNumberExecutor //===--------------------------------------------------------------------===// @@ -10,14 +81,100 @@ WindowRowNumberExecutor::WindowRowNumberExecutor(BoundWindowExpression &wexpr, C : WindowExecutor(wexpr, context, shared) { } +unique_ptr WindowRowNumberExecutor::GetGlobalState(const idx_t payload_count, + const ValidityMask &partition_mask, + const ValidityMask &order_mask) const { + return make_uniq(*this, payload_count, partition_mask, order_mask); +} + +unique_ptr +WindowRowNumberExecutor::GetLocalState(const WindowExecutorGlobalState &gstate) const { + return make_uniq(gstate.Cast()); +} + void WindowRowNumberExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx) const { - auto &lbstate = lstate.Cast(); - auto partition_begin = FlatVector::GetData(lbstate.bounds.data[PARTITION_BEGIN]); + auto &grstate = gstate.Cast(); + auto &lrstate = lstate.Cast(); + auto partition_begin = FlatVector::GetData(lrstate.bounds.data[PARTITION_BEGIN]); + auto rdata = FlatVector::GetData(result); + + if (grstate.token_tree) { + auto partition_end = FlatVector::GetData(lrstate.bounds.data[PARTITION_END]); + for (idx_t i = 0; i < count; ++i, ++row_idx) { + // Row numbers are unique ranks + rdata[i] = grstate.token_tree->Rank(partition_begin[i], partition_end[i], row_idx); + } + return; + } + + for (idx_t i = 0; i < count; ++i, ++row_idx) { + rdata[i] = row_idx - partition_begin[i] + 1; + } +} + +//===--------------------------------------------------------------------===// +// WindowNtileExecutor +//===--------------------------------------------------------------------===// +WindowNtileExecutor::WindowNtileExecutor(BoundWindowExpression &wexpr, ClientContext &context, + WindowSharedExpressions &shared) + : WindowRowNumberExecutor(wexpr, context, shared) { + + // NTILE has one argument + ntile_idx = shared.RegisterEvaluate(wexpr.children[0]); +} + +void WindowNtileExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, + DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx) const { + auto &grstate = gstate.Cast(); + auto &lrstate = lstate.Cast(); + auto partition_begin = FlatVector::GetData(lrstate.bounds.data[PARTITION_BEGIN]); + auto partition_end = FlatVector::GetData(lrstate.bounds.data[PARTITION_END]); auto rdata = FlatVector::GetData(result); + WindowInputExpression ntile_col(eval_chunk, ntile_idx); for (idx_t i = 0; i < count; ++i, ++row_idx) { - rdata[i] = NumericCast(row_idx - partition_begin[i] + 1); + if (ntile_col.CellIsNull(i)) { + FlatVector::SetNull(result, i, true); + } else { + auto n_param = ntile_col.GetCell(i); + if (n_param < 1) { + throw InvalidInputException("Argument for ntile must be greater than zero"); + } + // With thanks from SQLite's ntileValueFunc() + auto n_total = NumericCast(partition_end[i] - partition_begin[i]); + if (n_param > n_total) { + // more groups allowed than we have values + // map every entry to a unique group + n_param = n_total; + } + int64_t n_size = (n_total / n_param); + // find the row idx within the group + D_ASSERT(row_idx >= partition_begin[i]); + idx_t partition_idx = 0; + if (grstate.token_tree) { + partition_idx = grstate.token_tree->Rank(partition_begin[i], partition_end[i], row_idx) - 1; + } else { + partition_idx = row_idx - partition_begin[i]; + } + auto adjusted_row_idx = NumericCast(partition_idx); + + // now compute the ntile + int64_t n_large = n_total - n_param * n_size; + int64_t i_small = n_large * (n_size + 1); + int64_t result_ntile; + + D_ASSERT((n_large * (n_size + 1) + (n_param - n_large) * n_size) == n_total); + + if (adjusted_row_idx < i_small) { + result_ntile = 1 + adjusted_row_idx / (n_size + 1); + } else { + result_ntile = 1 + n_large + (adjusted_row_idx - i_small) / n_size; + } + // result has to be between [1, NTILE] + D_ASSERT(result_ntile >= 1 && result_ntile <= n_param); + rdata[i] = result_ntile; + } } } diff --git a/src/duckdb/src/function/window/window_value_function.cpp b/src/duckdb/src/function/window/window_value_function.cpp index 3f261845..ce776b7d 100644 --- a/src/duckdb/src/function/window/window_value_function.cpp +++ b/src/duckdb/src/function/window/window_value_function.cpp @@ -144,15 +144,6 @@ WindowValueExecutor::WindowValueExecutor(BoundWindowExpression &wexpr, ClientCon offset_idx = shared.RegisterEvaluate(wexpr.offset_expr); default_idx = shared.RegisterEvaluate(wexpr.default_expr); - - for (const auto &order : wexpr.arg_orders) { - arg_order_idx.emplace_back(shared.RegisterSink(order.expression)); - } -} - -WindowNtileExecutor::WindowNtileExecutor(BoundWindowExpression &wexpr, ClientContext &context, - WindowSharedExpressions &shared) - : WindowValueExecutor(wexpr, context, shared) { } unique_ptr WindowValueExecutor::GetGlobalState(const idx_t payload_count, @@ -174,51 +165,6 @@ unique_ptr WindowValueExecutor::GetLocalState(const Wi return make_uniq(gvstate); } -void WindowNtileExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, - DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx) const { - auto &lvstate = lstate.Cast(); - auto &cursor = *lvstate.cursor; - auto partition_begin = FlatVector::GetData(lvstate.bounds.data[PARTITION_BEGIN]); - auto partition_end = FlatVector::GetData(lvstate.bounds.data[PARTITION_END]); - auto rdata = FlatVector::GetData(result); - for (idx_t i = 0; i < count; ++i, ++row_idx) { - if (cursor.CellIsNull(0, row_idx)) { - FlatVector::SetNull(result, i, true); - } else { - auto n_param = cursor.GetCell(0, row_idx); - if (n_param < 1) { - throw InvalidInputException("Argument for ntile must be greater than zero"); - } - // With thanks from SQLite's ntileValueFunc() - auto n_total = NumericCast(partition_end[i] - partition_begin[i]); - if (n_param > n_total) { - // more groups allowed than we have values - // map every entry to a unique group - n_param = n_total; - } - int64_t n_size = (n_total / n_param); - // find the row idx within the group - D_ASSERT(row_idx >= partition_begin[i]); - auto adjusted_row_idx = NumericCast(row_idx - partition_begin[i]); - // now compute the ntile - int64_t n_large = n_total - n_param * n_size; - int64_t i_small = n_large * (n_size + 1); - int64_t result_ntile; - - D_ASSERT((n_large * (n_size + 1) + (n_param - n_large) * n_size) == n_total); - - if (adjusted_row_idx < i_small) { - result_ntile = 1 + adjusted_row_idx / (n_size + 1); - } else { - result_ntile = 1 + n_large + (adjusted_row_idx - i_small) / n_size; - } - // result has to be between [1, NTILE] - D_ASSERT(result_ntile >= 1 && result_ntile <= n_param); - rdata[i] = result_ntile; - } - } -} - //===--------------------------------------------------------------------===// // WindowLeadLagLocalState //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/include/duckdb/execution/merge_sort_tree.hpp b/src/duckdb/src/include/duckdb/execution/merge_sort_tree.hpp index 5d37f66f..672aaa56 100644 --- a/src/duckdb/src/include/duckdb/execution/merge_sort_tree.hpp +++ b/src/duckdb/src/include/duckdb/execution/merge_sort_tree.hpp @@ -125,7 +125,7 @@ struct MergeSortTree { } template - void AggregateLowerBound(const idx_t lower, const idx_t upper, const idx_t needle, L aggregate) const; + void AggregateLowerBound(const idx_t lower, const idx_t upper, const E needle, L aggregate) const; Tree tree; CompareElements cmp; @@ -571,7 +571,7 @@ idx_t MergeSortTree::SelectNth(const SubFrames &frames, idx_t n template template -void MergeSortTree::AggregateLowerBound(const idx_t lower, const idx_t upper, const idx_t needle, +void MergeSortTree::AggregateLowerBound(const idx_t lower, const idx_t upper, const E needle, L aggregate) const { if (lower >= upper) { @@ -638,7 +638,7 @@ void MergeSortTree::AggregateLowerBound(const idx_t lower, cons // Search based on cascading info from previous level const auto *search_begin = level_data + cascading_idcs[cascading_idx.first]; const auto *search_end = level_data + cascading_idcs[cascading_idx.first + FANOUT]; - const auto run_pos = std::lower_bound(search_begin, search_end, needle) - level_data; + const auto run_pos = std::lower_bound(search_begin, search_end, needle, cmp.cmp) - level_data; // Compute runBegin and pass it to our callback const auto run_begin = curr.first - level_width; aggregate(level, run_begin, NumericCast(run_pos)); @@ -650,7 +650,7 @@ void MergeSortTree::AggregateLowerBound(const idx_t lower, cons if (curr.first != lower) { const auto *search_begin = level_data + cascading_idcs[cascading_idx.first]; const auto *search_end = level_data + cascading_idcs[cascading_idx.first + FANOUT]; - auto idx = NumericCast(std::lower_bound(search_begin, search_end, needle) - level_data); + auto idx = NumericCast(std::lower_bound(search_begin, search_end, needle, cmp.cmp) - level_data); cascading_idx.first = (idx / CASCADING + 2 * (lower / level_width)) * FANOUT; } @@ -660,7 +660,7 @@ void MergeSortTree::AggregateLowerBound(const idx_t lower, cons // Search based on cascading info from previous level const auto *search_begin = level_data + cascading_idcs[cascading_idx.second]; const auto *search_end = level_data + cascading_idcs[cascading_idx.second + FANOUT]; - const auto run_pos = std::lower_bound(search_begin, search_end, needle) - level_data; + const auto run_pos = std::lower_bound(search_begin, search_end, needle, cmp.cmp) - level_data; // Compute runBegin and pass it to our callback const auto run_begin = curr.second; aggregate(level, run_begin, NumericCast(run_pos)); @@ -672,7 +672,7 @@ void MergeSortTree::AggregateLowerBound(const idx_t lower, cons if (curr.second != upper) { const auto *search_begin = level_data + cascading_idcs[cascading_idx.second]; const auto *search_end = level_data + cascading_idcs[cascading_idx.second + FANOUT]; - auto idx = NumericCast(std::lower_bound(search_begin, search_end, needle) - level_data); + auto idx = NumericCast(std::lower_bound(search_begin, search_end, needle, cmp.cmp) - level_data); cascading_idx.second = (idx / CASCADING + 2 * (upper / level_width)) * FANOUT; } } while (level >= LowestCascadingLevel()); @@ -688,7 +688,7 @@ void MergeSortTree::AggregateLowerBound(const idx_t lower, cons const auto *search_end = level_data + curr.first; const auto *search_begin = search_end - level_width; const auto run_pos = - NumericCast(std::lower_bound(search_begin, search_end, needle) - level_data); + NumericCast(std::lower_bound(search_begin, search_end, needle, cmp.cmp) - level_data); const auto run_begin = NumericCast(search_begin - level_data); aggregate(level, run_begin, run_pos); curr.first -= level_width; @@ -698,7 +698,7 @@ void MergeSortTree::AggregateLowerBound(const idx_t lower, cons const auto *search_begin = level_data + curr.second; const auto *search_end = search_begin + level_width; const auto run_pos = - NumericCast(std::lower_bound(search_begin, search_end, needle) - level_data); + NumericCast(std::lower_bound(search_begin, search_end, needle, cmp.cmp) - level_data); const auto run_begin = NumericCast(search_begin - level_data); aggregate(level, run_begin, run_pos); curr.second += level_width; @@ -714,7 +714,7 @@ void MergeSortTree::AggregateLowerBound(const idx_t lower, cons while (lower_it != curr.first) { const auto *search_begin = level_data + lower_it; const auto run_begin = lower_it; - const auto run_pos = run_begin + (*search_begin < needle); + const auto run_pos = run_begin + cmp.cmp(*search_begin, needle); aggregate(level, run_begin, run_pos); ++lower_it; } @@ -722,7 +722,7 @@ void MergeSortTree::AggregateLowerBound(const idx_t lower, cons while (curr.second != upper) { const auto *search_begin = level_data + curr.second; const auto run_begin = curr.second; - const auto run_pos = run_begin + (*search_begin < needle); + const auto run_pos = run_begin + cmp.cmp(*search_begin, needle); aggregate(level, run_begin, run_pos); ++curr.second; } diff --git a/src/duckdb/src/include/duckdb/function/window/window_executor.hpp b/src/duckdb/src/include/duckdb/function/window/window_executor.hpp index 44d3a905..0c5e6354 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_executor.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_executor.hpp @@ -114,6 +114,9 @@ class WindowExecutor { optional_ptr range_expr; column_t range_idx = DConstants::INVALID_INDEX; + //! The column indices of any ORDER BY argument expressions + vector arg_order_idx; + protected: virtual void EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx) const = 0; diff --git a/src/duckdb/src/include/duckdb/function/window/window_merge_sort_tree.hpp b/src/duckdb/src/include/duckdb/function/window/window_merge_sort_tree.hpp index 9b9a16f1..6faecbbf 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_merge_sort_tree.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_merge_sort_tree.hpp @@ -53,7 +53,7 @@ class WindowMergeSortTree { using LocalSortStatePtr = unique_ptr; WindowMergeSortTree(ClientContext &context, const vector &orders, - const vector &sort_idx, const idx_t count); + const vector &sort_idx, const idx_t count, bool unique = false); virtual ~WindowMergeSortTree() = default; virtual unique_ptr GetLocalState() = 0; diff --git a/src/duckdb/src/include/duckdb/function/window/window_rank_function.hpp b/src/duckdb/src/include/duckdb/function/window/window_rank_function.hpp index 4c84d1a4..e1a0cf48 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_rank_function.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_rank_function.hpp @@ -19,9 +19,6 @@ class WindowPeerExecutor : public WindowExecutor { unique_ptr GetGlobalState(const idx_t payload_count, const ValidityMask &partition_mask, const ValidityMask &order_mask) const override; unique_ptr GetLocalState(const WindowExecutorGlobalState &gstate) const override; - - //! The column indices of any ORDER BY argument expressions - vector arg_order_idx; }; class WindowRankExecutor : public WindowPeerExecutor { diff --git a/src/duckdb/src/include/duckdb/function/window/window_rownumber_function.hpp b/src/duckdb/src/include/duckdb/function/window/window_rownumber_function.hpp index 46b118e9..7ee0979c 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_rownumber_function.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_rownumber_function.hpp @@ -16,6 +16,23 @@ class WindowRowNumberExecutor : public WindowExecutor { public: WindowRowNumberExecutor(BoundWindowExpression &wexpr, ClientContext &context, WindowSharedExpressions &shared); + unique_ptr GetGlobalState(const idx_t payload_count, const ValidityMask &partition_mask, + const ValidityMask &order_mask) const override; + unique_ptr GetLocalState(const WindowExecutorGlobalState &gstate) const override; + + //! The evaluation index of the NTILE column + column_t ntile_idx = DConstants::INVALID_INDEX; + +protected: + void EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, DataChunk &eval_chunk, + Vector &result, idx_t count, idx_t row_idx) const override; +}; + +// NTILE is just scaled ROW_NUMBER +class WindowNtileExecutor : public WindowRowNumberExecutor { +public: + WindowNtileExecutor(BoundWindowExpression &wexpr, ClientContext &context, WindowSharedExpressions &shared); + protected: void EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx) const override; diff --git a/src/duckdb/src/include/duckdb/function/window/window_token_tree.hpp b/src/duckdb/src/include/duckdb/function/window/window_token_tree.hpp index 8f69a41d..b5501255 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_token_tree.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_token_tree.hpp @@ -16,12 +16,12 @@ namespace duckdb { class WindowTokenTree : public WindowMergeSortTree { public: WindowTokenTree(ClientContext &context, const vector &orders, const vector &sort_idx, - const idx_t count) - : WindowMergeSortTree(context, orders, sort_idx, count) { + const idx_t count, bool unique = false) + : WindowMergeSortTree(context, orders, sort_idx, count, unique) { } WindowTokenTree(ClientContext &context, const BoundOrderModifier &order_bys, const vector &sort_idx, - const idx_t count) - : WindowTokenTree(context, order_bys.orders, sort_idx, count) { + const idx_t count, bool unique = false) + : WindowTokenTree(context, order_bys.orders, sort_idx, count, unique) { } unique_ptr GetLocalState() override; diff --git a/src/duckdb/src/include/duckdb/function/window/window_value_function.hpp b/src/duckdb/src/include/duckdb/function/window/window_value_function.hpp index 42d2cb19..bc0457ce 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_value_function.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_value_function.hpp @@ -32,19 +32,8 @@ class WindowValueExecutor : public WindowExecutor { column_t offset_idx = DConstants::INVALID_INDEX; //! The column index of the default value column column_t default_idx = DConstants::INVALID_INDEX; - //! The column indices of the argument ORDER BY expressions - vector arg_order_idx; }; -// -class WindowNtileExecutor : public WindowValueExecutor { -public: - WindowNtileExecutor(BoundWindowExpression &wexpr, ClientContext &context, WindowSharedExpressions &shared); - -protected: - void EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, DataChunk &eval_chunk, - Vector &result, idx_t count, idx_t row_idx) const override; -}; class WindowLeadLagExecutor : public WindowValueExecutor { public: WindowLeadLagExecutor(BoundWindowExpression &wexpr, ClientContext &context, WindowSharedExpressions &shared); diff --git a/src/duckdb/src/parser/transform/expression/transform_function.cpp b/src/duckdb/src/parser/transform/expression/transform_function.cpp index d9a52ce1..fd12b6ce 100644 --- a/src/duckdb/src/parser/transform/expression/transform_function.cpp +++ b/src/duckdb/src/parser/transform/expression/transform_function.cpp @@ -117,14 +117,14 @@ static bool IsOrderableWindowFunction(ExpressionType type) { case ExpressionType::WINDOW_NTH_VALUE: case ExpressionType::WINDOW_RANK: case ExpressionType::WINDOW_PERCENT_RANK: + case ExpressionType::WINDOW_ROW_NUMBER: + case ExpressionType::WINDOW_NTILE: return true; case ExpressionType::WINDOW_LEAD: case ExpressionType::WINDOW_LAG: case ExpressionType::WINDOW_AGGREGATE: - case ExpressionType::WINDOW_ROW_NUMBER: case ExpressionType::WINDOW_RANK_DENSE: case ExpressionType::WINDOW_CUME_DIST: - case ExpressionType::WINDOW_NTILE: return false; default: throw InternalException("Unknown orderable window type %s", ExpressionTypeToString(type).c_str());